"""Build and verify threshold-gate weights for an 8-bit leading-one detector.

The circuit takes the eight input bits x0..x7 and produces four outputs:
pos2, pos1, pos0 (binary position of the highest-order 1 bit) and valid
(0 when the input is all zeros).

Running the script writes the weights to 'model.safetensors', checks the
stored gates against a reference implementation for all 256 inputs, and
prints a few examples plus the total weight magnitude and parameter count.
"""
import torch
from safetensors.torch import save_file

weights = {}

# Priority logic: the highest set bit wins.
#   leader[i] = x[i] AND NOT(any higher bit set)
#
# Gate convention assumed throughout: a gate outputs 1 when
# weight . inputs + bias >= 0. This is the convention under which the biases
# below realise OR (bias -1), NOT (weight -1, bias 0) and two-input AND
# (bias -2, which must fire at exactly 0).

# "Higher-than" masks: h[i] = OR of x[j] for every j > i.
#   h7 = 0 (nothing higher), h6 = x7, h5 = x7 OR x6, ...
# Column j of each weight row corresponds to input bit x[j]. For i = 7 the
# row is all zeros, so with bias -1 the gate never fires.
for i in range(8):
    w = [1.0 if j > i else 0.0 for j in range(8)]
    weights[f'h{i}.weight'] = torch.tensor([w], dtype=torch.float32)
    weights[f'h{i}.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# nh[i] = NOT h[i]
for i in range(8):
    weights[f'nh{i}.weight'] = torch.tensor([[-1.0]], dtype=torch.float32)
    weights[f'nh{i}.bias'] = torch.tensor([0.0], dtype=torch.float32)

# leader[i] = x[i] AND nh[i]
for i in range(8):
    weights[f'l{i}.weight'] = torch.tensor([[1.0, 1.0]], dtype=torch.float32)
    weights[f'l{i}.bias'] = torch.tensor([-2.0], dtype=torch.float32)

# Encode the leader position as three bits (each one a four-input OR):
#   pos2 = leader[7] OR leader[6] OR leader[5] OR leader[4]
#   pos1 = leader[7] OR leader[6] OR leader[3] OR leader[2]
#   pos0 = leader[7] OR leader[5] OR leader[3] OR leader[1]
weights['pos2.weight'] = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32)
weights['pos2.bias'] = torch.tensor([-1.0], dtype=torch.float32)
weights['pos1.weight'] = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32)
weights['pos1.bias'] = torch.tensor([-1.0], dtype=torch.float32)
weights['pos0.weight'] = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32)
weights['pos0.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# valid = OR of all eight input bits
weights['valid.weight'] = torch.tensor([[1.0] * 8], dtype=torch.float32)
weights['valid.bias'] = torch.tensor([-1.0], dtype=torch.float32)

save_file(weights, 'model.safetensors')


def leading_one(x):
    """Reference implementation: locate the highest set bit among bits 0..7.

    Returns a ``(position, valid)`` pair of ints. ``valid`` is 1 when one of
    the low eight bits of ``x`` is set and ``position`` is the index of the
    highest such bit; when none is set the result is ``(0, 0)``.
    """
    for i in range(7, -1, -1):
        if (x >> i) & 1:
            return i, 1
    return 0, 0


def _fire(name, inputs):
    """Evaluate the stored gate ``name`` on a list of 0/1 inputs.

    Reads ``weights[f'{name}.weight']`` and ``weights[f'{name}.bias']`` and
    returns 1 when weight . inputs + bias >= 0, else 0 (see the gate
    convention note near the top of the file).

    Raises ValueError if the number of inputs does not match the weight row,
    which would indicate a wiring mistake in the verification code.
    """
    row = weights[f'{name}.weight'][0].tolist()
    bias = weights[f'{name}.bias'].item()
    if len(row) != len(inputs):
        raise ValueError(
            f"gate {name!r} expects {len(row)} inputs, got {len(inputs)}"
        )
    activation = sum(w_k * x_k for w_k, x_k in zip(row, inputs)) + bias
    return 1 if activation >= 0 else 0


def _run_circuit(x):
    """Run the stored gates on the low eight bits of ``x``.

    Wires the gates as described by the comments where the weights are
    built and returns ``(position, valid)`` as ints.
    """
    bits = [(x >> i) & 1 for i in range(8)]

    leaders = []
    for i in range(8):
        higher = _fire(f'h{i}', bits)
        not_higher = _fire(f'nh{i}', [higher])
        leaders.append(_fire(f'l{i}', [bits[i], not_higher]))

    pos2 = _fire('pos2', [leaders[7], leaders[6], leaders[5], leaders[4]])
    pos1 = _fire('pos1', [leaders[7], leaders[6], leaders[3], leaders[2]])
    pos0 = _fire('pos0', [leaders[7], leaders[5], leaders[3], leaders[1]])
    valid = _fire('valid', bits)

    return (pos2 << 2) | (pos1 << 1) | pos0, valid


print("Verifying 8-bit Leading One Detector...")
errors = 0
for x in range(256):
    pos, valid = leading_one(x)
    # Check the actual stored weights, not a hand-written copy of the logic,
    # so a wrong weight or bias is reported instead of silently passing.
    calc_pos, calc_valid = _run_circuit(x)

    if calc_pos != pos or calc_valid != valid:
        errors += 1
        # Only show the first few mismatches to keep the output readable.
        if errors <= 5:
            print(f"ERROR: x={x:08b}, got pos={calc_pos} valid={calc_valid}, expected pos={pos} valid={valid}")

if errors == 0:
    print("All 256 test cases passed!")
else:
    print(f"FAILED: {errors} errors")

print("\nExamples:")
for x in [0, 1, 2, 4, 8, 16, 32, 64, 128, 255, 0b10101010]:
    pos, valid = leading_one(x)
    print(f"  {x:3d} ({x:08b}): position={pos}, valid={valid}")

# Summary statistics over every stored tensor (weights and biases alike).
mag = sum(t.abs().sum().item() for t in weights.values())
print(f"\nMagnitude: {mag:.0f}")
print(f"Parameters: {sum(t.numel() for t in weights.values())}")