# threshold-leading-one-detect / create_safetensors.py
# Last commit 57ca140 by CharlesCNorton: "8-bit leading one detector" (3.67 kB)
import torch
from safetensors.torch import save_file
weights = {}

# 8-bit Leading One Detector (Find First Set from MSB)
# Inputs:  x7..x0 (8 bits).
#          NOTE(review): the h{i} masks below put their ones at vector
#          indices j > i, so they assume index j carries bit x_j (LSB first).
#          Confirm this ordering for anything that feeds the network.
# Outputs: pos2, pos1, pos0 (binary position of the highest set bit) and
#          valid (0 when the input is all zeros; pos is then 0 as well).
#
# Every gate is a threshold unit: it outputs 1 when
# weight . inputs + bias >= 0 and 0 otherwise.
#
# Priority logic: the highest set bit wins,
#   pos = 7 if x7, else 6 if x6, else 5 if x5, ...
# For each position compute "this position is the leader":
#   leader[i] = x[i] AND NOT(any higher bit set)
#
# Higher-than masks: h{i} = OR of every bit above position i, e.g.
#   h7 = 0 (nothing higher), h6 = x7, h5 = x7 OR x6, ...
# h7 needs no special case. Its mask is all zeros, so z = 0 - 1 < 0 and the
# gate is constantly 0.
for i in range(8):
    mask = [1.0 if j > i else 0.0 for j in range(8)]
    weights[f'h{i}.weight'] = torch.tensor([mask], dtype=torch.float32)
    weights[f'h{i}.bias'] = torch.tensor([-1.0], dtype=torch.float32)
# NOT gates: nh{i} = NOT h{i}.
# Single input, weight -1, bias 0: the gate fires when h{i} is 0
# (0 + 0 >= 0) and stays silent when h{i} is 1 (-1 + 0 < 0).
weights.update({
    key: tensor
    for idx in range(8)
    for key, tensor in (
        (f'nh{idx}.weight', torch.tensor([[-1.0]], dtype=torch.float32)),
        (f'nh{idx}.bias', torch.tensor([0.0], dtype=torch.float32)),
    )
})
# Leader gates: leader[i] = x[i] AND NOT(higher), a 2-input AND.
# Both weights are 1 and the bias is -2, so the gate fires only when both
# inputs are 1 (1 + 1 - 2 >= 0). The weights are symmetric, so input order
# does not matter.
for bit in range(8):
    prefix = f'l{bit}'
    weights[prefix + '.weight'] = torch.tensor([[1.0, 1.0]], dtype=torch.float32)
    weights[prefix + '.bias'] = torch.tensor([-2.0], dtype=torch.float32)
# Encode the position of the (single) active leader in binary.
# Each output bit is a 4-input OR: all weights 1, bias -1, so it fires
# when at least one input is 1.
#   pos2 = leader[7] OR leader[6] OR leader[5] OR leader[4]
#   pos1 = leader[7] OR leader[6] OR leader[3] OR leader[2]
#   pos0 = leader[7] OR leader[5] OR leader[3] OR leader[1]
# The wiring above is not stored in the tensors. Whatever evaluates the
# network must feed exactly these leaders into each output gate.
for out_bit in ('pos2', 'pos1', 'pos0'):
    weights[f'{out_bit}.weight'] = torch.tensor([[1.0] * 4], dtype=torch.float32)
    weights[f'{out_bit}.bias'] = torch.tensor([-1.0], dtype=torch.float32)
# valid = OR of all 8 input bits: an 8-input gate with unit weights and
# bias -1, which is 0 only for the all-zero input.
weights['valid.weight'] = torch.ones(1, 8, dtype=torch.float32)
weights['valid.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# Persist every gate's weight/bias tensors to a single safetensors file.
save_file(weights, 'model.safetensors')
def leading_one(x):
    """Reference model: find the most significant set bit among the low 8 bits of x.

    Bits above position 7 are ignored. For a negative int, Python's
    two's-complement view of the low byte is used.

    Returns:
        ``(position, 1)`` with position in 0..7 when any of the low 8 bits
        is set, otherwise ``(0, 0)``.
    """
    if x == 0:
        return 0, 0
    low_byte = int(x & 0xFF)
    if low_byte:
        # bit_length() - 1 is the index of the highest set bit.
        return low_byte.bit_length() - 1, 1
    return 0, 0
# Verify the network itself, not only the logic it is meant to implement.
# Every gate is evaluated straight from the tensors in `weights`, and the
# result is compared with the pure-Python reference `leading_one` on all
# 256 inputs. A wrong weight or bias in `weights` is therefore caught here.


def _gate(params, name, inputs):
    """Evaluate threshold gate `name`: 1.0 if weight . inputs + bias >= 0, else 0.0.

    The >= 0 firing rule is the one the biases in this script are built for
    (OR: bias -1, 2-input AND: bias -2, NOT: weight -1 / bias 0).

    Args:
        params: mapping of '<name>.weight' / '<name>.bias' to float32 tensors.
        name: gate prefix, e.g. 'h3' or 'pos1'.
        inputs: 1-D float32 tensor whose length matches the weight row.

    Returns:
        1-element float32 tensor holding 0.0 or 1.0.
    """
    z = params[f'{name}.weight'] @ inputs + params[f'{name}.bias']
    return (z >= 0).float()


def _run_network(params, x):
    """Run the leading-one-detector gates on the low 8 bits of `x`.

    Assumes input index j carries bit x_j (LSB first), which is what the
    h{i} masks (ones at indices j > i) require. NOTE(review): the script's
    header comment lists the inputs as x7..x0. Confirm the intended ordering
    for external consumers.

    Returns:
        (pos, valid) as Python ints, in the same form as `leading_one`.
    """
    bits = torch.tensor([float((x >> i) & 1) for i in range(8)], dtype=torch.float32)
    leader = []
    for i in range(8):
        higher = _gate(params, f'h{i}', bits)
        not_higher = _gate(params, f'nh{i}', higher)
        leader.append(_gate(params, f'l{i}', torch.cat([bits[i:i + 1], not_higher])))
    # Encoder wiring: pos2 <- leaders 7,6,5,4; pos1 <- 7,6,3,2; pos0 <- 7,5,3,1.
    pos2 = _gate(params, 'pos2', torch.cat([leader[7], leader[6], leader[5], leader[4]]))
    pos1 = _gate(params, 'pos1', torch.cat([leader[7], leader[6], leader[3], leader[2]]))
    pos0 = _gate(params, 'pos0', torch.cat([leader[7], leader[5], leader[3], leader[1]]))
    valid_out = _gate(params, 'valid', bits)
    position = (int(pos2.item()) << 2) | (int(pos1.item()) << 1) | int(pos0.item())
    return position, int(valid_out.item())


print("Verifying 8-bit Leading One Detector...")
errors = 0
for x in range(256):
    pos, valid = leading_one(x)
    calc_pos, calc_valid = _run_network(weights, x)
    if calc_pos != pos or calc_valid != valid:
        errors += 1
        # Report only the first few mismatches to keep the output readable.
        if errors <= 5:
            print(f"ERROR: x={x:08b}, got pos={calc_pos} valid={calc_valid}, expected pos={pos} valid={valid}")
if errors == 0:
    print("All 256 test cases passed!")
else:
    print(f"FAILED: {errors} errors")
# Show the reference model on hand-picked inputs: zero, each single bit,
# all ones, and an alternating bit pattern.
print("\nExamples:")
samples = [0] + [1 << k for k in range(8)] + [255, 0b10101010]
for sample in samples:
    where, ok = leading_one(sample)
    print(f" {sample:3d} ({sample:08b}): position={where}, valid={ok}")
# Summary over every tensor in the network: the total absolute magnitude of
# all weights and biases, and the total count of scalar parameters.
tensors = list(weights.values())
mag = sum(t.abs().sum().item() for t in tensors)
param_count = sum(t.numel() for t in tensors)
print(f"\nMagnitude: {mag:.0f}")
print(f"Parameters: {param_count}")