"""Threshold-neuron weights for a 4-bit-mantissa minifloat normalizer.

Normalization shifts the mantissa left until its MSB is 1 and lowers the
exponent by the same amount. This script builds the leading-one detector
neurons for that operation, checks them (and the pure-Python reference
``normalize``) over every possible input, and writes the weights to
``model.safetensors``.

Input vector layout for every neuron (7 inputs, MSB first):
    [M3, M2, M1, M0, E2, E1, E0]

NOTE(review): only the leading-one detectors are encoded as neurons; the
mantissa shift and exponent adjustment exist only in ``normalize``.
"""

import torch

MANT_BITS = 4
EXP_BITS = 3

# Neuron parameters keyed '<name>.weight' (shape (1, 7)) and '<name>.bias' (shape (1,)).
weights = {}


def add_neuron(name, w_list, bias):
    """Register neuron `name` in ``weights`` with a single weight row and a bias."""
    weights[f'{name}.weight'] = torch.tensor([w_list], dtype=torch.float32)
    weights[f'{name}.bias'] = torch.tensor([bias], dtype=torch.float32)


def fire(name, inputs):
    """Evaluate neuron `name` on a 0/1 input vector with a Heaviside step.

    Uses the convention ``1 if w.x + b >= 0 else 0``. This is the convention
    under which 'm3_is1' (weight +1 on M3, bias -1) fires exactly when M3 == 1.
    NOTE(review): confirm this matches whatever consumes model.safetensors.

    Returns 1 if the neuron fires, else 0.
    """
    x = torch.tensor(inputs, dtype=torch.float32)
    pre = weights[f'{name}.weight'] @ x + weights[f'{name}.bias']
    return int(pre.item() >= 0)


# Leading-one detectors, ordered by the left shift they imply (index == shift).
# Each must fire iff its bit is 1 AND every higher mantissa bit is 0:
# weight +1 on the bit, -1 on each higher bit, bias -1. Any larger bias lets
# the neuron fire when its bit is 0 or a higher bit is already set.
LEAD_DETECTORS = ('m3_is1', 'm2_lead', 'm1_lead', 'm0_lead')
add_neuron('m3_is1', [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -1.0)  # 1xxx -> shift 0
add_neuron('m2_lead', [-1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], -1.0)  # 01xx -> shift 1
add_neuron('m1_lead', [-1.0, -1.0, 1.0, 0.0, 0.0, 0.0, 0.0], -1.0)  # 001x -> shift 2
add_neuron('m0_lead', [-1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0], -1.0)  # 0001 -> shift 3


def _bits(value, width):
    """Return the low `width` bits of `value` as an MSB-first tuple of 0/1 ints."""
    return tuple((value >> i) & 1 for i in reversed(range(width)))


def normalize(m3, m2, m1, m0, e2, e1, e0):
    """Reference normalizer operating on individual bits (MSB first).

    Shifts the 4-bit mantissa left until its MSB is 1 and subtracts the
    shift count from the 3-bit exponent, saturating at 0. A zero mantissa
    yields all zeros regardless of the exponent.

    Returns:
        Tuple (M3, M2, M1, M0, E2, E1, E0) of the normalized result.

    NOTE(review): when the exponent is smaller than the shift, it saturates
    at 0 while the mantissa is still fully shifted. If value = M * 2**E,
    the encoded value therefore changes; there is no subnormal handling.
    Confirm this is intended.
    """
    m = m3 * 8 + m2 * 4 + m1 * 2 + m0
    e = e2 * 4 + e1 * 2 + e0
    if m == 0:
        return (0,) * (MANT_BITS + EXP_BITS)
    shift = 0
    # For 0/1 inputs m is in 1..15 here, so at most 3 shifts are needed;
    # the shift < 4 cap is only a guard against a non-terminating loop.
    while (m & 8) == 0 and shift < 4:
        m = (m << 1) & 0xF
        shift += 1
    e = max(0, e - shift)
    return _bits(m, MANT_BITS) + _bits(e, EXP_BITS)


def _expected_lead(m):
    """Outputs a correct detector bank must give for mantissa `m` (one-hot at the shift)."""
    if m == 0:
        return [0] * MANT_BITS
    shift = MANT_BITS - m.bit_length()  # number of leading zeros in a 4-bit field
    return [int(i == shift) for i in range(MANT_BITS)]


def main():
    """Save the weights, verify detectors and `normalize` exhaustively, print stats."""
    # Imported here so the module can be imported (e.g. for tests) without safetensors.
    from safetensors.torch import save_file

    save_file(weights, 'model.safetensors')

    print("Verifying FP normalize...")
    errors = 0
    for m in range(2 ** MANT_BITS):
        for e in range(2 ** EXP_BITS):
            bits = _bits(m, MANT_BITS) + _bits(e, EXP_BITS)
            result = normalize(*bits)
            # Reference must leave the MSB set (or the value zero).
            if m != 0 and result[0] != 1:
                errors += 1
            # The neuron detectors must agree with the leading-one position.
            if [fire(name, bits) for name in LEAD_DETECTORS] != _expected_lead(m):
                errors += 1
    if errors == 0:
        print("All test cases passed!")
    else:
        print(f"FAILED: {errors} errors")

    mag = sum(t.abs().sum().item() for t in weights.values())
    print(f"Magnitude: {mag:.0f}")
    print(f"Parameters: {sum(t.numel() for t in weights.values())}")


if __name__ == "__main__":
    main()