"""Build and verify a 4-input prefix-sum (scan) network of threshold neurons.

Inputs are four single-bit values ordered [X3, X2, X1, X0]. The running sums are
Yi = X0 + ... + Xi. For single-bit inputs each Yi is a running popcount, so it
is represented in thermometer code: neuron ``y{i}_ge{k}`` (weight 1 on X0..Xi,
bias -k) fires when X0 + ... + Xi >= k, and Yi is the number of firing
``y{i}_ge*`` neurons. Y0 = X0 needs only the single neuron ``y0``.

The weights are written to ``model.safetensors``. The script then checks the
stored network (not just a Python formula) against the reference prefix sums.
"""
import torch
from safetensors.torch import save_file

weights = {}

# 4-element Prefix Sum (Scan)
# Computes running sum: Y[i] = sum(X[0:i+1])


def add_neuron(name, w_list, bias):
    """Register one neuron: weight tensor of shape (1, len(w_list)), bias of shape (1,)."""
    weights[f'{name}.weight'] = torch.tensor([w_list], dtype=torch.float32)
    weights[f'{name}.bias'] = torch.tensor([bias], dtype=torch.float32)


N_INPUTS = 4

for i in range(N_INPUTS):
    # Inputs are ordered [X3, X2, X1, X0], so prefix X0..Xi occupies the last i+1 slots.
    w = [0.0] * (N_INPUTS - 1 - i) + [1.0] * (i + 1)
    if i == 0:
        # Y0 = X0 is a single bit: one neuron firing when X0 >= 1.
        add_neuron('y0', w, -1.0)
    else:
        # Yi in 0..i+1: one ">= k" neuron per possible nonzero count.
        for k in range(1, i + 2):
            add_neuron(f'y{i}_ge{k}', w, -float(k))

save_file(weights, 'model.safetensors')


def prefix_sum(x3, x2, x1, x0):
    """Reference running sums, returned as (Y3, Y2, Y1, Y0)."""
    y0 = x0
    y1 = x0 + x1
    y2 = x0 + x1 + x2
    y3 = x0 + x1 + x2 + x3
    return y3, y2, y1, y0


def _neuron_fires(name, x):
    """Return 1 if neuron `name` fires on input list x ([X3, X2, X1, X0]), else 0.

    Assumes a Heaviside step that fires when w.x + b >= 0. This matches the
    ``_geK`` naming with bias -K; NOTE(review): confirm against the runtime
    that consumes model.safetensors.
    """
    x_t = torch.tensor(x, dtype=torch.float32)
    # (1, 4) @ (4,) -> (1,); the sums are of small integers, so float32 is exact here.
    z = (weights[f'{name}.weight'] @ x_t + weights[f'{name}.bias']).item()
    return int(z >= 0)


def network_prefix_sum(x3, x2, x1, x0):
    """Evaluate the stored neurons and decode their outputs to (Y3, Y2, Y1, Y0)."""
    x = [x3, x2, x1, x0]
    ys = [_neuron_fires('y0', x)]
    for i in range(1, N_INPUTS):
        # Thermometer decode: the count of firing ">= k" neurons equals the sum.
        ys.append(sum(_neuron_fires(f'y{i}_ge{k}', x) for k in range(1, i + 2)))
    return tuple(reversed(ys))


print("Verifying prefix sum...")
errors = 0
for v in range(16):
    bits = ((v >> 3) & 1, (v >> 2) & 1, (v >> 1) & 1, v & 1)
    # Compare the actual network output with the reference, instead of
    # checking the reference formula against itself.
    if network_prefix_sum(*bits) != prefix_sum(*bits):
        errors += 1

if errors == 0:
    print("All 16 test cases passed!")
else:
    print(f"FAILED: {errors} errors")

mag = sum(t.abs().sum().item() for t in weights.values())
print(f"Magnitude: {mag:.0f}")
print(f"Parameters: {sum(t.numel() for t in weights.values())}")