"""Build, verify, and save a 2:1 multiplexer made of threshold neurons.

Inputs: d0, d1, s
Output: d0 if s=0, d1 if s=1

Formula: out = (d0 AND NOT s) OR (d1 AND s)

Layer 1:
    sel0: d0 AND NOT(s) - fires when d0=1 and s=0
    sel1: d1 AND s      - fires when d1=1 and s=1

Layer 2:
    or: OR(sel0, sel1)

A neuron "fires" (outputs 1) when ``weight . input + bias >= 0``.
"""

from itertools import product

import torch

weights = {}

# sel0: d0 AND NOT(s)
# Weights: [d0, d1, s] = [1, 0, -1], bias = -1
#   d0=1, s=0: 1 + 0 - 0 - 1 =  0 >= 0 -> fires
#   d0=1, s=1: 1 + 0 - 1 - 1 = -1 <  0 -> doesn't fire
weights['sel0.weight'] = torch.tensor([[1.0, 0.0, -1.0]], dtype=torch.float32)
weights['sel0.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# sel1: d1 AND s
# Weights: [d0, d1, s] = [0, 1, 1], bias = -2
#   d1=1, s=1: 0 + 1 + 1 - 2 =  0 >= 0 -> fires
#   d1=1, s=0: 0 + 1 + 0 - 2 = -1 <  0 -> doesn't fire
weights['sel1.weight'] = torch.tensor([[0.0, 1.0, 1.0]], dtype=torch.float32)
weights['sel1.bias'] = torch.tensor([-2.0], dtype=torch.float32)

# or: OR(sel0, sel1)
# Weights: [sel0, sel1] = [1, 1], bias = -1 -> fires when at least one input is 1
weights['or.weight'] = torch.tensor([[1.0, 1.0]], dtype=torch.float32)
weights['or.bias'] = torch.tensor([-1.0], dtype=torch.float32)


def _fires(gate, inputs):
    """Return 1 if the threshold neuron named *gate* fires on *inputs*, else 0."""
    pre_activation = inputs @ weights[f'{gate}.weight'].T + weights[f'{gate}.bias']
    return int((pre_activation >= 0).item())


def mux2(d0, d1, s):
    """Evaluate the two-layer multiplexer on bits d0, d1 and select bit s.

    Returns d0 when s is 0 and d1 when s is 1, as an int (0 or 1).
    """
    inp = torch.tensor([float(d0), float(d1), float(s)])
    layer1 = torch.tensor([float(_fires('sel0', inp)), float(_fires('sel1', inp))])
    return _fires('or', layer1)


def verify():
    """Check mux2 against the truth table for all 8 inputs; return the error count."""
    print("Verifying MUX2...")
    errors = 0
    # Same enumeration order as nested loops over s, then d0, then d1.
    for s, d0, d1 in product([0, 1], repeat=3):
        result = mux2(d0, d1, s)
        expected = d1 if s else d0
        if result != expected:
            errors += 1
            print(f"ERROR: mux2({d0}, {d1}, {s}) = {result}, expected {expected}")
    return errors


def main(path='model.safetensors'):
    """Verify the circuit, report its size, and save the weights to *path*.

    Returns a process exit status: 0 on success, 1 if verification failed.
    The weights file is written only after verification passes, so a broken
    weight table never lands on disk.
    """
    errors = verify()
    if errors == 0:
        print("All 8 test cases passed!")
    else:
        print(f"FAILED: {errors} errors")

    mag = sum(t.abs().sum().item() for t in weights.values())
    print(f"Magnitude: {mag:.0f}")
    print(f"Parameters: {sum(t.numel() for t in weights.values())}")

    if errors:
        return 1

    # Imported here so the module (weights, mux2) can be imported without
    # the serialization dependency; it is only needed to write the file.
    from safetensors.torch import save_file

    save_file(weights, path)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())