import torch from safetensors.torch import save_file # Layer 1: N1 checks sum >= 1, N2 checks sum <= 1 # Layer 2: AND(N1, N2) = exactly 1 weights = { 'layer1.weight': torch.tensor([ [1.0, 1.0, 1.0, 1.0, 1.0], # N1: sum >= 1 [-1.0, -1.0, -1.0, -1.0, -1.0] # N2: sum <= 1 ], dtype=torch.float32), 'layer1.bias': torch.tensor([-1.0, 1.0], dtype=torch.float32), 'layer2.weight': torch.tensor([[1.0, 1.0]], dtype=torch.float32), 'layer2.bias': torch.tensor([-2.0], dtype=torch.float32) } save_file(weights, 'model.safetensors') def exactly1of5(a, b, c, d, e): inp = torch.tensor([float(a), float(b), float(c), float(d), float(e)]) l1 = (inp @ weights['layer1.weight'].T + weights['layer1.bias'] >= 0).float() out = (l1 @ weights['layer2.weight'].T + weights['layer2.bias'] >= 0).float() return int(out.item()) print("Verifying exactly1outof5...") errors = 0 for i in range(32): bits = [(i >> j) & 1 for j in range(5)] result = exactly1of5(*bits) expected = 1 if sum(bits) == 1 else 0 if result != expected: errors += 1 print(f"ERROR: {bits} -> {result}, expected {expected}") if errors == 0: print("All 32 test cases passed!") print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")