import torch from safetensors.torch import save_file # Priority encoder: outputs binary index of highest-set input # Inputs: I3, I2, I1, I0 (I3 is highest priority) # Outputs: Y1, Y0 (binary encoding) weights = {} # Y1: fires when I3 or I2 is set (highest bit is 2 or 3) weights['y1.weight'] = torch.tensor([[1.0, 1.0, 0.0, 0.0]], dtype=torch.float32) weights['y1.bias'] = torch.tensor([-1.0], dtype=torch.float32) # Y0: fires when I3 is set, OR when I2 is NOT set but I1 is set # This gives output bit 0 for indices 1 and 3 weights['y0.weight'] = torch.tensor([[3.0, -2.0, 1.0, 0.0]], dtype=torch.float32) weights['y0.bias'] = torch.tensor([-1.0], dtype=torch.float32) save_file(weights, 'model.safetensors') def encode4to2(i3, i2, i1, i0): inp = torch.tensor([float(i3), float(i2), float(i1), float(i0)]) y1 = int((inp @ weights['y1.weight'].T + weights['y1.bias'] >= 0).item()) y0 = int((inp @ weights['y0.weight'].T + weights['y0.bias'] >= 0).item()) return y1, y0 print("Verifying 4to2encoder...") errors = 0 for val in range(16): i3, i2, i1, i0 = (val >> 3) & 1, (val >> 2) & 1, (val >> 1) & 1, val & 1 y1, y0 = encode4to2(i3, i2, i1, i0) # Expected: binary of highest set bit position if i3: expected = (1, 1) # 3 elif i2: expected = (1, 0) # 2 elif i1: expected = (0, 1) # 1 else: expected = (0, 0) # 0 or none if (y1, y0) != expected: errors += 1 print(f"ERROR: I={i3}{i2}{i1}{i0} -> ({y1},{y0}), expected {expected}") if errors == 0: print("All 16 test cases passed!") print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")