| import torch |
| from safetensors.torch import save_file |
|
|
# Threshold-gate weights for an 8-bit leading-one detector.
# Every unit is stored as a (weight, bias) pair and presumably evaluated as
# step(weight . inputs + bias).
# NOTE(review): the NOT/AND/OR gates defined in this file only work if
# step(z) == 1 for z >= 0 (i.e. it fires at exactly 0). Confirm this against
# the runtime that loads model.safetensors.
weights = {}

# Layer h{i}: step(sum_{j > i} x_j - 1), i.e. fires when any input at an
# index strictly greater than i is 1. If inputs are ordered LSB-first
# (index j == bit j), that means "some bit above bit i is set".
# For i = 7 there is no index above it, so the row is all zeros and the
# pre-activation is the constant -1 (the unit never fires under a >= 0 step).
# A single comprehension covers that case; no special branch is needed.
for i in range(8):
    w = [1.0 if j > i else 0.0 for j in range(8)]
    weights[f'h{i}.weight'] = torch.tensor([w], dtype=torch.float32)
    weights[f'h{i}.bias'] = torch.tensor([-1.0], dtype=torch.float32)
|
|
| |
# Layer nh{i}: single-input gate step(-h + 0). Under a step that fires at 0,
# this is a NOT gate: it outputs 1 only when its input is 0.
# Presumably its input is h{i}, giving "no bit above i is set".
for i in range(8):
    weights.update({
        f'nh{i}.weight': torch.tensor([[-1.0]], dtype=torch.float32),
        f'nh{i}.bias': torch.tensor([0.0], dtype=torch.float32),
    })
|
|
| |
# Layer l{i}: two-input gate step(a + b - 2). It reaches 0 only when both
# inputs are 1, so (with a step that fires at 0) it is an AND gate.
# Presumably fed (bit i, nh{i}), marking bit i as the leading one.
for i in range(8):
    for suffix, value in (('weight', [[1.0, 1.0]]), ('bias', [-2.0])):
        weights[f'l{i}.{suffix}'] = torch.tensor(value, dtype=torch.float32)
|
|
| |
| |
| |
| |
|
|
# Output bits pos2, pos1, pos0: four-input gates step(sum - 1), which fire
# when at least one input is 1 (an OR gate, given a step that fires at 0).
# All weights are equal, so only the choice of inputs distinguishes them.
# Presumably pos2 ORs l4..l7, pos1 ORs l2, l3, l6, l7 and pos0 ORs l1, l3,
# l5, l7, which yields the leader's index in binary (TODO confirm wiring).
for bit in (2, 1, 0):
    weights[f'pos{bit}.weight'] = torch.tensor([[1.0] * 4], dtype=torch.float32)
    weights[f'pos{bit}.bias'] = torch.tensor([-1.0], dtype=torch.float32)
|
|
| |
# valid: eight-input OR gate step(sum - 1). It fires when any input is 1;
# presumably its inputs are the eight raw input bits.
weights.update(
    {
        'valid.weight': torch.tensor([[1.0] * 8], dtype=torch.float32),
        'valid.bias': torch.tensor([-1.0], dtype=torch.float32),
    }
)
|
|
# Serialize every gate's tensors into one safetensors file (relative path,
# so it lands in the current working directory).
save_file(tensors=weights, filename='model.safetensors')
|
|
def leading_one(x, width=8):
    """Return ``(position, valid)`` for the most significant set bit of *x*.

    Only the low *width* bits of *x* are examined (default 8, matching the
    8-bit detector). ``position`` is the 0-based index of the highest set
    bit inside that window, and ``valid`` is 1. If no bit in the window is
    set, ``(0, 0)`` is returned. Bits above the window are ignored, so with
    the default width ``leading_one(256) == (0, 0)``.

    Negative *x* is examined in its two's-complement form, so ``-1`` yields
    ``(width - 1, 1)``. *width* must be non-negative.

    >>> leading_one(0b00101100)
    (5, 1)
    >>> leading_one(0)
    (0, 0)
    """
    # Mask to the detector's input window; for ints, & with a positive mask
    # also maps negative values to their two's-complement low bits.
    window = x & ((1 << width) - 1)
    if not window:
        return 0, 0
    return window.bit_length() - 1, 1
|
|
def _fire(name, inputs):
    """Evaluate threshold unit *name* from ``weights`` on *inputs*; return 0/1.

    Computes ``weight @ inputs + bias`` and returns 1 when that is >= 0.
    NOTE(review): the step-at-zero convention (fire when the sum is exactly 0)
    is what the NOT (bias 0) and AND (bias -2) gates need. Confirm it matches
    the runtime that loads model.safetensors.
    """
    x = torch.tensor(inputs, dtype=torch.float32)
    z = weights[f'{name}.weight'] @ x + weights[f'{name}.bias']
    return int(z.item() >= 0)


# Check the *stored weights* (not a Python re-derivation of the logic)
# against the reference leading_one() for every 8-bit input.
# Wiring below is assumed from the gate names; TODO confirm against the
# model's forward pass:
#   h_i   <- all 8 bits (LSB-first)
#   nh_i  <- h_i
#   l_i   <- (bit_i, nh_i)
#   pos*  <- leaders as listed
#   valid <- all 8 bits
print("Verifying 8-bit Leading One Detector...")
errors = 0
for x in range(256):
    pos, valid = leading_one(x)
    bits = [(x >> i) & 1 for i in range(8)]  # bits[i] is bit i (LSB first)

    higher = [_fire(f'h{i}', bits) for i in range(8)]
    not_higher = [_fire(f'nh{i}', [higher[i]]) for i in range(8)]
    leaders = [_fire(f'l{i}', [bits[i], not_higher[i]]) for i in range(8)]

    # Exactly one leader fires for x != 0; OR-ing leader groups yields its
    # index in binary.
    calc_pos2 = _fire('pos2', [leaders[7], leaders[6], leaders[5], leaders[4]])
    calc_pos1 = _fire('pos1', [leaders[7], leaders[6], leaders[3], leaders[2]])
    calc_pos0 = _fire('pos0', [leaders[7], leaders[5], leaders[3], leaders[1]])
    calc_pos = (calc_pos2 << 2) | (calc_pos1 << 1) | calc_pos0
    calc_valid = _fire('valid', bits)

    if calc_pos != pos or calc_valid != valid:
        errors += 1
        if errors <= 5:
            print(f"ERROR: x={x:08b}, got pos={calc_pos} valid={calc_valid}, expected pos={pos} valid={valid}")

if errors == 0:
    print("All 256 test cases passed!")
else:
    print(f"FAILED: {errors} errors")
|
|
# Show the reference encoder on a few hand-picked inputs: zero, every single
# bit, all ones, and an alternating pattern.
print("\nExamples:")
sample_inputs = [0, *(1 << shift for shift in range(8)), 255, 0b10101010]
for value in sample_inputs:
    position, is_valid = leading_one(value)
    print(f" {value:3d} ({value:08b}): position={position}, valid={is_valid}")
|
|
# Summary statistics over all stored tensors: L1 magnitude and element count.
all_tensors = list(weights.values())
total_magnitude = sum(tensor.abs().sum().item() for tensor in all_tensors)
param_count = sum(tensor.numel() for tensor in all_tensors)
print(f"\nMagnitude: {total_magnitude:.0f}")
print(f"Parameters: {param_count}")
|
|