""" Build tools for 8-bit Threshold Computer safetensors. Subcommands: python build.py memory - Generate 64KB memory circuits python build.py inputs - Add .inputs metadata tensors python build.py all - Run both (memory first, then inputs) ROUTING SCHEMA (formerly routing.json) ====================================== Routing info is now embedded in safetensors via .inputs tensors and signal registry metadata. INPUT SOURCE TYPES ------------------ 1. External input: "$input_name" - Named input to the circuit - Example: "$a", "$b", "$cin" 2. Gate output: "path.to.gate" - Output of another gate - Example: "ha1.sum", "layer1.or" 3. Bit extraction: "$input[i]" - Single bit from multi-bit input - Example: "$a[0]" (LSB), "$a[7]" (MSB for 8-bit) 4. Constant: "#0" or "#1" - Fixed value - Example: "#1" for carry-in in two's complement CIRCUIT TYPES ------------- Single-Layer Gates: .weight and .bias only "boolean.and": ["$a", "$b"] Two-Layer Gates (XOR, XNOR): layer1 + layer2 "boolean.xor.layer1.or": ["$a", "$b"] "boolean.xor.layer1.nand": ["$a", "$b"] "boolean.xor.layer2": ["layer1.or", "layer1.nand"] Hierarchical Circuits: nested sub-components "arithmetic.fulladder": { "ha1.sum.layer1.or": ["$a", "$b"], "ha1.carry": ["$a", "$b"], "ha2.sum.layer1.or": ["ha1.sum", "$cin"], "carry_or": ["ha1.carry", "ha2.carry"] } Bit-Indexed Circuits: multi-bit operations "arithmetic.ripplecarry8bit.fa0": ["$a[0]", "$b[0]", "#0"] "arithmetic.ripplecarry8bit.fa1": ["$a[1]", "$b[1]", "fa0.cout"] PACKED MEMORY CIRCUITS ---------------------- 64KB memory uses packed tensors (shapes for 16-bit address, 8-bit data): memory.addr_decode.weight: [65536, 16] memory.addr_decode.bias: [65536] memory.read.and.weight: [8, 65536, 2] memory.read.and.bias: [8, 65536] memory.read.or.weight: [8, 65536] memory.read.or.bias: [8] memory.write.sel.weight: [65536, 2] memory.write.sel.bias: [65536] memory.write.nsel.weight: [65536, 1] memory.write.nsel.bias: [65536] memory.write.and_old.weight: [65536, 8, 2] memory.write.and_old.bias: [65536, 8] memory.write.and_new.weight: [65536, 8, 2] memory.write.and_new.bias: [65536, 8] memory.write.or.weight: [65536, 8, 2] memory.write.or.bias: [65536, 8] Semantics: decode: sel[i] = H(sum(addr_bits * weight[i]) + bias[i]) read: bit[b] = H(sum(H([mem_bit, sel] * and_w) + and_b) * or_w + or_b) write: new = H(H([old, nsel] * and_old) + H([data, sel] * and_new) - 1) SIGNAL REGISTRY --------------- Signal IDs are stored in safetensors metadata as JSON: {"0": "#0", "1": "#1", "2": "$a", "3": "$b", ...} Each gate's .inputs tensor contains integer IDs referencing this registry. NAMING CONVENTIONS ------------------ - External inputs: $name or $name[bit] - Constants: #0, #1 - Internal gates: relative path from circuit root """ from __future__ import annotations import argparse import json import re from pathlib import Path from typing import Dict, Iterable, List, Set import torch from safetensors import safe_open from safetensors.torch import save_file MODEL_PATH = Path(__file__).resolve().parent / "neural_computer.safetensors" MANIFEST_PATH = Path(__file__).resolve().parent / "tensors.txt" DEFAULT_ADDR_BITS = 16 DEFAULT_MEM_BYTES = 1 << DEFAULT_ADDR_BITS MEMORY_PROFILES = { "full": 16, # 64KB - full CPU mode "reduced": 12, # 4KB - reduced CPU "scratchpad": 8, # 256 bytes - LLM scratchpad "registers": 4, # 16 bytes - LLM register file "none": 0, # Pure ALU, no memory } def load_tensors(path: Path) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} with safe_open(str(path), framework="pt") as f: for name in f.keys(): tensors[name] = f.get_tensor(name).clone() return tensors def get_all_gates(tensors: Dict[str, torch.Tensor]) -> Set[str]: gates = set() for name in tensors: if name.endswith('.weight'): gates.add(name[:-7]) return gates class SignalRegistry: def __init__(self): self.name_to_id: Dict[str, int] = {} self.id_to_name: Dict[int, str] = {} self.next_id = 0 self.register("#0") self.register("#1") def register(self, name: str) -> int: if name not in self.name_to_id: self.name_to_id[name] = self.next_id self.id_to_name[self.next_id] = name self.next_id += 1 return self.name_to_id[name] def get_id(self, name: str) -> int: return self.name_to_id.get(name, -1) def to_metadata(self) -> str: return json.dumps(self.id_to_name) def add_gate(tensors: Dict[str, torch.Tensor], name: str, weight: Iterable[float], bias: Iterable[float]) -> None: w_key = f"{name}.weight" b_key = f"{name}.bias" if w_key in tensors or b_key in tensors: raise ValueError(f"Gate already exists: {name}") tensors[w_key] = torch.tensor(list(weight), dtype=torch.float32) tensors[b_key] = torch.tensor(list(bias), dtype=torch.float32) def drop_prefixes(tensors: Dict[str, torch.Tensor], prefixes: List[str]) -> None: for key in list(tensors.keys()): if any(key.startswith(prefix) for prefix in prefixes): del tensors[key] def add_decoder(tensors: Dict[str, torch.Tensor], addr_bits: int, mem_bytes: int) -> None: weights = torch.empty((mem_bytes, addr_bits), dtype=torch.float32) bias = torch.empty((mem_bytes,), dtype=torch.float32) for addr in range(mem_bytes): bits = [(addr >> (addr_bits - 1 - i)) & 1 for i in range(addr_bits)] weights[addr] = torch.tensor([1.0 if bit == 1 else -1.0 for bit in bits], dtype=torch.float32) bias[addr] = -float(sum(bits)) tensors["memory.addr_decode.weight"] = weights tensors["memory.addr_decode.bias"] = bias def add_memory_read_mux(tensors: Dict[str, torch.Tensor], mem_bytes: int) -> None: and_weight = torch.ones((8, mem_bytes, 2), dtype=torch.float32) and_bias = torch.full((8, mem_bytes), -2.0, dtype=torch.float32) or_weight = torch.ones((8, mem_bytes), dtype=torch.float32) or_bias = torch.full((8,), -1.0, dtype=torch.float32) tensors["memory.read.and.weight"] = and_weight tensors["memory.read.and.bias"] = and_bias tensors["memory.read.or.weight"] = or_weight tensors["memory.read.or.bias"] = or_bias def add_memory_write_cells(tensors: Dict[str, torch.Tensor], mem_bytes: int) -> None: sel_weight = torch.ones((mem_bytes, 2), dtype=torch.float32) sel_bias = torch.full((mem_bytes,), -2.0, dtype=torch.float32) nsel_weight = torch.full((mem_bytes, 1), -1.0, dtype=torch.float32) nsel_bias = torch.zeros((mem_bytes,), dtype=torch.float32) and_old_weight = torch.ones((mem_bytes, 8, 2), dtype=torch.float32) and_old_bias = torch.full((mem_bytes, 8), -2.0, dtype=torch.float32) and_new_weight = torch.ones((mem_bytes, 8, 2), dtype=torch.float32) and_new_bias = torch.full((mem_bytes, 8), -2.0, dtype=torch.float32) or_weight = torch.ones((mem_bytes, 8, 2), dtype=torch.float32) or_bias = torch.full((mem_bytes, 8), -1.0, dtype=torch.float32) tensors["memory.write.sel.weight"] = sel_weight tensors["memory.write.sel.bias"] = sel_bias tensors["memory.write.nsel.weight"] = nsel_weight tensors["memory.write.nsel.bias"] = nsel_bias tensors["memory.write.and_old.weight"] = and_old_weight tensors["memory.write.and_old.bias"] = and_old_bias tensors["memory.write.and_new.weight"] = and_new_weight tensors["memory.write.and_new.bias"] = and_new_bias tensors["memory.write.or.weight"] = or_weight tensors["memory.write.or.bias"] = or_bias def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor], addr_bits: int) -> None: for bit in range(16): add_gate(tensors, f"control.fetch.ir.bit{bit}", [1.0], [-1.0]) for bit in range(8): add_gate(tensors, f"control.load.bit{bit}", [1.0], [-1.0]) add_gate(tensors, f"control.store.bit{bit}", [1.0], [-1.0]) for bit in range(addr_bits): add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0]) def add_full_adder(tensors: Dict[str, torch.Tensor], prefix: str) -> None: """Add a single full adder at the given prefix. Full adder structure: - ha1: first half adder (A XOR B for sum, A AND B for carry) - ha2: second half adder (ha1.sum XOR Cin for sum, ha1.sum AND Cin for carry) - carry_or: OR of ha1.carry and ha2.carry for final carry out """ # XOR for ha1.sum (2-layer: OR + NAND -> AND) add_gate(tensors, f"{prefix}.ha1.sum.layer1.or", [1.0, 1.0], [-1.0]) add_gate(tensors, f"{prefix}.ha1.sum.layer1.nand", [-1.0, -1.0], [1.0]) add_gate(tensors, f"{prefix}.ha1.sum.layer2", [1.0, 1.0], [-2.0]) # AND for ha1.carry add_gate(tensors, f"{prefix}.ha1.carry", [1.0, 1.0], [-2.0]) # XOR for ha2.sum add_gate(tensors, f"{prefix}.ha2.sum.layer1.or", [1.0, 1.0], [-1.0]) add_gate(tensors, f"{prefix}.ha2.sum.layer1.nand", [-1.0, -1.0], [1.0]) add_gate(tensors, f"{prefix}.ha2.sum.layer2", [1.0, 1.0], [-2.0]) # AND for ha2.carry add_gate(tensors, f"{prefix}.ha2.carry", [1.0, 1.0], [-2.0]) # OR for final carry add_gate(tensors, f"{prefix}.carry_or", [1.0, 1.0], [-1.0]) def add_expr_add_mul(tensors: Dict[str, torch.Tensor]) -> None: """Add expression circuit for A + B × C (order of operations). Computes A + (B × C) where multiplication has higher precedence. Structure: - Stage 1: Multiply B × C using shift-add algorithm - 8 mask stages: mask[i] = B AND C[i] (8 AND gates each, shifted) - 7 accumulator adders to sum masked values - Stage 2: Add A to multiplication result (8-bit ripple carry) Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each) Output: 8-bit result of A + (B × C), wrapping on overflow Total: 64 AND gates + 7×8 full adders (mul) + 8 full adders (add) = ~640 gates """ prefix = "arithmetic.expr_add_mul" # Stage 1: Multiply B × C using shift-add # For each bit i of C, we AND all bits of B with C[i] # This creates partial products that are shifted by i positions # Mask AND gates: mask[stage][bit] = B[bit] AND C[stage] # These compute B & (C[i] ? 0xFF : 0x00) for each bit of C for stage in range(8): for bit in range(8): add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0]) # Accumulator adders for shift-add multiplication # Stage 0: acc = mask0 (no adder needed, just the masked value) # Stage 1-7: acc = acc + (mask[i] << i) # We need to handle the shifting by connecting different bit positions # For proper shift-add, we need adders that accumulate partial products # Each stage adds a shifted partial product to the accumulator # Using 16-bit internal accumulator, output low 8 bits # Simplified approach: chain of 8-bit adders with proper bit alignment # acc_stage[i] = acc_stage[i-1] + (mask[i] << i) # We keep only low 8 bits at each stage for 8-bit result for stage in range(1, 8): # 7 accumulator adders for bit in range(8): add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}") # Stage 2: Add A to multiplication result for bit in range(8): add_full_adder(tensors, f"{prefix}.add.fa{bit}") def add_expr_paren_add_mul(tensors: Dict[str, torch.Tensor]) -> None: """Add expression circuit for (A + B) × C (parenthetical override). Computes (A + B) × C where parentheses override normal precedence. Addition happens first, then multiplication. Structure: - Stage 1: Add A + B (8-bit ripple carry adder) - Stage 2: Multiply sum × C using shift-add algorithm - 8 mask stages: mask[i] = sum AND C[i] (8 AND gates each) - 7 accumulator adders to sum shifted masked values Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each) Output: 8-bit result of (A + B) × C, wrapping on overflow Total: 8 full adders (add) + 64 AND gates + 56 full adders (mul) = ~640 gates """ prefix = "arithmetic.expr_paren_add_mul" # Stage 1: Add A + B for bit in range(8): add_full_adder(tensors, f"{prefix}.add.fa{bit}") # Stage 2: Multiply sum × C using shift-add # Mask AND gates: mask[stage][bit] = sum[bit] AND C[stage] for stage in range(8): for bit in range(8): add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0]) # Accumulator adders for shift-add multiplication for stage in range(1, 8): # 7 accumulator adders for bit in range(8): add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}") def add_expr_paren(tensors: Dict[str, torch.Tensor]) -> None: """Add expression circuit for (A + B) × C (parenthetical grouping). Computes (A + B) × C where addition happens first due to parentheses. Structure: - Stage 1: Add A + B (8-bit ripple carry) - Stage 2: Multiply sum × C using shift-add algorithm - 8 mask stages: mask[i] = sum AND C[i] (8 AND gates each) - 7 accumulator adders to sum shifted masked values Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each) Output: 8-bit result of (A + B) × C, wrapping on overflow Total: 8 full adders (add) + 64 AND gates + 56 full adders (mul) = ~640 gates """ prefix = "arithmetic.expr_paren" # Stage 1: Add A + B for bit in range(8): add_full_adder(tensors, f"{prefix}.add.fa{bit}") # Stage 2: Multiply sum × C using shift-add # Mask AND gates: mask[stage][bit] = sum[bit] AND C[stage] for stage in range(8): for bit in range(8): add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0]) # Accumulator adders for shift-add multiplication for stage in range(1, 8): # 7 accumulator adders for bit in range(8): add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}") def add_add3(tensors: Dict[str, torch.Tensor]) -> None: """Add 3-operand 8-bit adder circuit. Computes A + B + C using two chained ripple-carry stages: - Stage 1: temp = A + B (8 full adders) - Stage 2: result = temp + C (8 full adders) Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first) Outputs: stage2.fa0-7.ha2.sum.layer2 (result bits), stage2.fa7.carry_or (overflow) Total: 16 full adders = 144 gates """ # Stage 1: A + B -> temp for bit in range(8): add_full_adder(tensors, f"arithmetic.add3_8bit.stage1.fa{bit}") # Stage 2: temp + C -> result for bit in range(8): add_full_adder(tensors, f"arithmetic.add3_8bit.stage2.fa{bit}") def add_shl_shr(tensors: Dict[str, torch.Tensor]) -> None: """Add SHL (shift left) and SHR (shift right) circuits. Identity gate: w=2, b=-1 -> H(x*2 - 1) = x for x in {0,1} Zero gate: w=0, b=-1 -> H(-1) = 0 SHL (MSB-first): out[i] = in[i+1] for i<7, out[7] = 0 SHR (MSB-first): out[0] = 0, out[i] = in[i-1] for i>0 """ for bit in range(8): if bit < 7: add_gate(tensors, f"alu.alu8bit.shl.bit{bit}", [2.0], [-1.0]) else: add_gate(tensors, f"alu.alu8bit.shl.bit{bit}", [0.0], [-1.0]) for bit in range(8): if bit > 0: add_gate(tensors, f"alu.alu8bit.shr.bit{bit}", [2.0], [-1.0]) else: add_gate(tensors, f"alu.alu8bit.shr.bit{bit}", [0.0], [-1.0]) def add_mul(tensors: Dict[str, torch.Tensor]) -> None: """Add 8-bit multiplication circuit. Produces low 8 bits of the 16-bit result. Structure: - 64 AND gates for partial products P[i][j] = A[i] AND B[j] - Uses existing ripple-carry adder components for summation The multiply method in ThresholdALU computes: 1. Partial products via these AND gates 2. Shift-add accumulation via existing 8-bit adder """ # AND gates for partial products: P[i][j] = A[i] AND B[j] # These compute whether bit i of A and bit j of B are both 1 for i in range(8): for j in range(8): add_gate(tensors, f"alu.alu8bit.mul.pp.a{i}b{j}", [1.0, 1.0], [-2.0]) def add_div(tensors: Dict[str, torch.Tensor]) -> None: """Add 8-bit division circuit. Produces quotient (8 bits) and remainder (8 bits). Uses restoring division algorithm: - 8 iterations, each producing one quotient bit - Each iteration: compare, conditionally subtract, shift Structure: - 8 comparison gates (one per iteration) - 8 conditional subtraction stages - Uses existing comparator and subtractor components """ # Comparison gates: check if (remainder << 1 | next_bit) >= divisor for stage in range(8): add_gate(tensors, f"alu.alu8bit.div.stage{stage}.cmp", [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0, -128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0], [0.0]) # Conditional mux gates: select (rem - div) or rem based on comparison for stage in range(8): for bit in range(8): # NOT for inverting comparison result add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.not_sel", [-1.0], [0.0]) # AND gates for mux add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.and_a", [1.0, 1.0], [-2.0]) add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.and_b", [1.0, 1.0], [-2.0]) # OR gate for mux output add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.or", [1.0, 1.0], [-1.0]) def add_inc_dec(tensors: Dict[str, torch.Tensor]) -> None: """Add INC and DEC circuits. INC: A + 1 using half adders with carry chain DEC: A - 1 using borrow chain (A + 255, two's complement of 1) For INC, we add 1 to the LSB and propagate carry. For DEC, we add 0xFF (two's complement of 1) or use borrow logic. """ # INC: half adder chain starting with carry_in = 1 # bit 7 (LSB): XOR with 1, carry = bit[7] # bit 6: XOR with carry, new_carry = bit[6] AND old_carry # ... for bit in range(8): # XOR for sum (two-layer) add_gate(tensors, f"alu.alu8bit.inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) add_gate(tensors, f"alu.alu8bit.inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) add_gate(tensors, f"alu.alu8bit.inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) # AND for carry propagation add_gate(tensors, f"alu.alu8bit.inc.bit{bit}.carry", [1.0, 1.0], [-2.0]) # DEC: similar but with borrow logic # Equivalent to adding 0xFF with carry_in = 0 # Or: NOT each bit, propagate borrow for bit in range(8): # XOR for difference add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) # Borrow: NOT(A) AND borrow_in, equivalent to (NOT A) when borrow_in=1 add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.not_a", [-1.0], [0.0]) add_gate(tensors, f"alu.alu8bit.dec.bit{bit}.borrow", [1.0, 1.0], [-2.0]) def add_neg(tensors: Dict[str, torch.Tensor]) -> None: """Add NEG circuit (two's complement negation). NEG(A) = NOT(A) + 1 = ~A + 1 Structure: NOT gates followed by INC-style adder. """ for bit in range(8): # NOT gate for each bit add_gate(tensors, f"alu.alu8bit.neg.not.bit{bit}", [-1.0], [0.0]) # Then add 1 using half adder chain add_gate(tensors, f"alu.alu8bit.neg.inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) add_gate(tensors, f"alu.alu8bit.neg.inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) add_gate(tensors, f"alu.alu8bit.neg.inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) add_gate(tensors, f"alu.alu8bit.neg.inc.bit{bit}.carry", [1.0, 1.0], [-2.0]) def add_rol_ror(tensors: Dict[str, torch.Tensor]) -> None: """Add ROL and ROR circuits (rotate left/right). ROL: out[i] = in[i+1] for i<7, out[7] = in[0] (MSB wraps to LSB) ROR: out[0] = in[7], out[i] = in[i-1] for i>0 (LSB wraps to MSB) Identity gates with circular wiring. """ # ROL: rotate left (toward MSB) for bit in range(8): src = (bit + 1) % 8 # Circular: bit 7 gets bit 0 add_gate(tensors, f"alu.alu8bit.rol.bit{bit}", [2.0], [-1.0]) # ROR: rotate right (toward LSB) for bit in range(8): src = (bit - 1) % 8 # Circular: bit 0 gets bit 7 add_gate(tensors, f"alu.alu8bit.ror.bit{bit}", [2.0], [-1.0]) def add_stack_ops(tensors: Dict[str, torch.Tensor]) -> None: """Add RET, PUSH, POP circuit components. These are higher-level operations that use memory read/write. We create the control logic gates. RET: Pop return address from stack, jump to it PUSH: Decrement SP, write value to [SP] POP: Read value from [SP], increment SP """ # SP decrement for PUSH (16-bit) for bit in range(16): add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) add_gate(tensors, f"control.push.sp_dec.bit{bit}.borrow", [1.0, 1.0], [-2.0]) # SP increment for POP (16-bit) for bit in range(16): add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0]) add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0]) add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0]) add_gate(tensors, f"control.pop.sp_inc.bit{bit}.carry", [1.0, 1.0], [-2.0]) # RET uses POP twice (for 16-bit address) then jumps # Buffer gates for return address for bit in range(16): add_gate(tensors, f"control.ret.addr.bit{bit}", [2.0], [-1.0]) def add_barrel_shifter(tensors: Dict[str, torch.Tensor]) -> None: """Add barrel shifter circuit. Shifts input by 0-7 positions based on 3-bit shift amount. Uses layers of 2:1 muxes controlled by shift amount bits. Layer 0: shift by 0 or 1 (controlled by shift[2], LSB) Layer 1: shift by 0 or 2 (controlled by shift[1]) Layer 2: shift by 0 or 4 (controlled by shift[0], MSB) """ # 3 layers of muxes, 8 bits each for layer in range(3): shift_amount = 1 << (2 - layer) # 4, 2, 1 for layers 0, 1, 2 for bit in range(8): # 2:1 mux: if sel then shifted else original # NOT for inverting select add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.not_sel", [-1.0], [0.0]) # AND gates add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.and_a", [1.0, 1.0], [-2.0]) add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.and_b", [1.0, 1.0], [-2.0]) # OR gate add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.or", [1.0, 1.0], [-1.0]) def add_priority_encoder(tensors: Dict[str, torch.Tensor]) -> None: """Add priority encoder circuit. Finds the position of the highest set bit (0-7). Output is 3-bit index + valid flag. Uses cascaded comparisons: check bit 7 first, then 6, etc. """ # Check each bit position (8 OR gates to detect any bit set at or above position) for pos in range(8): # OR of bits pos through 7 num_inputs = 8 - pos weights = [1.0] * num_inputs add_gate(tensors, f"combinational.priorityencoder.any_ge{pos}", weights, [-1.0]) # Priority logic: pos N is highest if bit N is set AND no higher bit is set for pos in range(8): # bit[pos] AND NOT(any bit > pos) add_gate(tensors, f"combinational.priorityencoder.is_highest{pos}.not_higher", [-1.0], [0.0]) add_gate(tensors, f"combinational.priorityencoder.is_highest{pos}.and", [1.0, 1.0], [-2.0]) # Encode position to 3-bit output # out[0] (LSB): positions 1,3,5,7 # out[1]: positions 2,3,6,7 # out[2] (MSB): positions 4,5,6,7 for out_bit in range(3): weights = [] for pos in range(8): if (pos >> out_bit) & 1: weights.append(1.0) if weights: add_gate(tensors, f"combinational.priorityencoder.out{out_bit}", weights, [-1.0]) # Valid flag: any bit set add_gate(tensors, f"combinational.priorityencoder.valid", [1.0] * 8, [-1.0]) def add_comparators(tensors: Dict[str, torch.Tensor]) -> None: """Add 8-bit comparator circuits (GT, LT, GE, LE, EQ). Each comparator takes 16 inputs (8 bits from A, 8 bits from B) in MSB-first order. Uses weighted sum comparison on the binary representation. For unsigned comparison of A vs B: - Assign positional weights: bit i has weight 2^(7-i) - A > B: sum(a_i * w_i) > sum(b_i * w_i) - This becomes: sum(a_i * w_i - b_i * w_i) > 0 - Or: sum((a_i - b_i) * w_i) > 0 Threshold gate: H(sum(x_i * w_i) + b) = 1 if sum >= -b For A > B: weights = [128, 64, 32, 16, 8, 4, 2, 1, -128, -64, -32, -16, -8, -4, -2, -1] bias = -1 (strictly greater, so need sum >= 1) For A >= B: bias = 0 (sum >= 0) For A < B: flip weights, bias = -1 For A <= B: flip weights, bias = 0 For A == B: need A >= B AND A <= B (two-layer) """ pos_weights = [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0] neg_weights = [-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0] gt_weights = pos_weights + neg_weights lt_weights = neg_weights + pos_weights add_gate(tensors, "arithmetic.greaterthan8bit", gt_weights, [-1.0]) add_gate(tensors, "arithmetic.greaterorequal8bit", gt_weights, [0.0]) add_gate(tensors, "arithmetic.lessthan8bit", lt_weights, [-1.0]) add_gate(tensors, "arithmetic.lessorequal8bit", lt_weights, [0.0]) add_gate(tensors, "arithmetic.equality8bit.layer1.geq", gt_weights, [0.0]) add_gate(tensors, "arithmetic.equality8bit.layer1.leq", lt_weights, [0.0]) add_gate(tensors, "arithmetic.equality8bit.layer2", [1.0, 1.0], [-2.0]) def update_manifest(tensors: Dict[str, torch.Tensor], addr_bits: int, mem_bytes: int) -> None: tensors["manifest.memory_bytes"] = torch.tensor([float(mem_bytes)], dtype=torch.float32) tensors["manifest.pc_width"] = torch.tensor([float(addr_bits)], dtype=torch.float32) tensors["manifest.version"] = torch.tensor([3.0], dtype=torch.float32) def write_manifest(path: Path, tensors: Dict[str, torch.Tensor]) -> None: lines: List[str] = [] lines.append("# Tensor Manifest") lines.append(f"# Total: {len(tensors)} tensors") for name in sorted(tensors.keys()): t = tensors[name] values = ", ".join(f"{v:.1f}" for v in t.flatten().tolist()) lines.append(f"{name}: shape={list(t.shape)}, values=[{values}]") path.write_text("\n".join(lines) + "\n", encoding="utf-8") def infer_boolean_inputs(gate: str, reg: SignalRegistry) -> List[int]: if gate == 'boolean.not': return [reg.register("$x")] if gate in ['boolean.and', 'boolean.or', 'boolean.nand', 'boolean.nor', 'boolean.implies']: return [reg.register("$a"), reg.register("$b")] if '.layer1.neuron1' in gate or '.layer1.neuron2' in gate or '.layer1.or' in gate or '.layer1.nand' in gate: return [reg.register("$a"), reg.register("$b")] if '.layer2' in gate: parent = gate.rsplit('.layer2', 1)[0] if '.layer1.neuron1' in parent or 'xor' in parent or 'xnor' in parent or 'biimplies' in parent: parent = parent.rsplit('.layer1', 1)[0] if '.layer1' in parent else parent return [reg.register(f"{parent}.layer1.or"), reg.register(f"{parent}.layer1.nand")] return [] def infer_halfadder_inputs(gate: str, prefix: str, reg: SignalRegistry) -> List[int]: a = reg.register(f"{prefix}.$a") b = reg.register(f"{prefix}.$b") if '.sum.layer1' in gate: return [a, b] if '.sum.layer2' in gate: return [reg.register(f"{prefix}.sum.layer1.or"), reg.register(f"{prefix}.sum.layer1.nand")] if '.carry' in gate and '.layer' not in gate: return [a, b] return [a, b] def infer_fulladder_inputs(gate: str, prefix: str, reg: SignalRegistry) -> List[int]: a = reg.register(f"{prefix}.$a") b = reg.register(f"{prefix}.$b") cin = reg.register(f"{prefix}.$cin") if '.ha1.sum.layer1' in gate: return [a, b] if '.ha1.sum.layer2' in gate: return [reg.register(f"{prefix}.ha1.sum.layer1.or"), reg.register(f"{prefix}.ha1.sum.layer1.nand")] if '.ha1.carry' in gate and '.layer' not in gate: return [a, b] if '.ha2.sum.layer1' in gate: return [reg.register(f"{prefix}.ha1.sum.layer2"), cin] if '.ha2.sum.layer2' in gate: return [reg.register(f"{prefix}.ha2.sum.layer1.or"), reg.register(f"{prefix}.ha2.sum.layer1.nand")] if '.ha2.carry' in gate and '.layer' not in gate: return [reg.register(f"{prefix}.ha1.sum.layer2"), cin] if '.carry_or' in gate: return [reg.register(f"{prefix}.ha1.carry"), reg.register(f"{prefix}.ha2.carry")] return [] def infer_ripplecarry_inputs(gate: str, prefix: str, bits: int, reg: SignalRegistry) -> List[int]: for i in range(bits): reg.register(f"{prefix}.$a[{i}]") reg.register(f"{prefix}.$b[{i}]") m = re.search(r'\.fa(\d+)\.', gate) if not m: return [] bit = int(m.group(1)) a_bit = reg.get_id(f"{prefix}.$a[{bit}]") b_bit = reg.get_id(f"{prefix}.$b[{bit}]") cin = reg.get_id("#0") if bit == 0 else reg.register(f"{prefix}.fa{bit-1}.carry_or") fa_prefix = f"{prefix}.fa{bit}" if '.ha1.sum.layer1' in gate: return [a_bit, b_bit] if '.ha1.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] if '.ha1.carry' in gate and '.layer' not in gate: return [a_bit, b_bit] if '.ha2.sum.layer1' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.ha2.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] if '.ha2.carry' in gate and '.layer' not in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.carry_or' in gate: return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] return [] def infer_expr_add_mul_inputs(gate: str, reg: SignalRegistry) -> List[int]: """Infer inputs for A + B × C expression circuit (order of operations). Circuit structure: - Mask stage: mask.s[stage].b[bit] = B[bit] AND C[stage] - Accumulator stages 1-7: acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage) - Final add: result = A + acc.s7 Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB) - $x[7] = bit 0 (LSB), $x[0] = bit 7 (MSB) """ prefix = "arithmetic.expr_add_mul" # Register all inputs for i in range(8): reg.register(f"$a[{i}]") reg.register(f"$b[{i}]") reg.register(f"$c[{i}]") # Mask AND gates: mask.s[stage].b[bit] = B[bit] AND C[stage] if '.mul.mask.' in gate: m = re.search(r'\.s(\d+)\.b(\d+)', gate) if m: stage = int(m.group(1)) bit = int(m.group(2)) # MSB-first: $b[7-bit] is bit position 'bit', $c[7-stage] is stage position 'stage' b_input = reg.get_id(f"$b[{7-bit}]") c_input = reg.get_id(f"$c[{7-stage}]") return [b_input, c_input] return [] # Accumulator adders: acc.s[stage].fa[bit] if '.mul.acc.' in gate: m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate) if not m: return [] stage = int(m.group(1)) # 1-7 bit = int(m.group(2)) # 0-7 # A input: previous stage output if stage == 1: # First accumulator: A = mask.s0.b[bit] (AND gate output) a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}") else: # Later stages: A = previous accumulator sum a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2") # B input: (mask.s[stage] << stage)[bit] # Shift left by 'stage' positions means: # - bit positions 0 to stage-1 get 0 # - bit position 'bit' gets mask.s[stage].b[bit-stage] if bit < stage: b_input = reg.get_id("#0") else: b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}") # Carry input if bit == 0: cin = reg.get_id("#0") else: cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or") fa_prefix = f"{prefix}.mul.acc.s{stage}.fa{bit}" if '.ha1.sum.layer1' in gate: return [a_input, b_input] if '.ha1.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] if '.ha1.carry' in gate and '.layer' not in gate: return [a_input, b_input] if '.ha2.sum.layer1' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.ha2.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] if '.ha2.carry' in gate and '.layer' not in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.carry_or' in gate: return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] return [] # Final add stage: A + mul_result if '.add.fa' in gate: m = re.search(r'\.fa(\d+)\.', gate) if not m: return [] bit = int(m.group(1)) # A input: $a[7-bit] (MSB-first to positional bit) a_input = reg.get_id(f"$a[{7-bit}]") # B input: multiplication result = acc.s7.fa[bit] sum output b_input = reg.register(f"{prefix}.mul.acc.s7.fa{bit}.ha2.sum.layer2") # Carry input if bit == 0: cin = reg.get_id("#0") else: cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or") fa_prefix = f"{prefix}.add.fa{bit}" if '.ha1.sum.layer1' in gate: return [a_input, b_input] if '.ha1.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] if '.ha1.carry' in gate and '.layer' not in gate: return [a_input, b_input] if '.ha2.sum.layer1' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.ha2.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] if '.ha2.carry' in gate and '.layer' not in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.carry_or' in gate: return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] return [] return [] def infer_expr_paren_add_mul_inputs(gate: str, reg: SignalRegistry) -> List[int]: """Infer inputs for (A + B) × C expression circuit (parenthetical override). Circuit structure: - Add stage: sum = A + B - Mask stage: mask.s[stage].b[bit] = sum[bit] AND C[stage] - Accumulator stages 1-7: acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage) Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB) """ prefix = "arithmetic.expr_paren_add_mul" # Register all inputs for i in range(8): reg.register(f"$a[{i}]") reg.register(f"$b[{i}]") reg.register(f"$c[{i}]") # Add stage: A + B if '.add.fa' in gate and '.mul.' not in gate: m = re.search(r'\.fa(\d+)\.', gate) if not m: return [] bit = int(m.group(1)) # A input: $a[7-bit], B input: $b[7-bit] a_input = reg.get_id(f"$a[{7-bit}]") b_input = reg.get_id(f"$b[{7-bit}]") # Carry input if bit == 0: cin = reg.get_id("#0") else: cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or") fa_prefix = f"{prefix}.add.fa{bit}" if '.ha1.sum.layer1' in gate: return [a_input, b_input] if '.ha1.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] if '.ha1.carry' in gate and '.layer' not in gate: return [a_input, b_input] if '.ha2.sum.layer1' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.ha2.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] if '.ha2.carry' in gate and '.layer' not in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.carry_or' in gate: return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] return [] # Mask AND gates: mask.s[stage].b[bit] = sum[bit] AND C[stage] if '.mul.mask.' in gate: m = re.search(r'\.s(\d+)\.b(\d+)', gate) if m: stage = int(m.group(1)) bit = int(m.group(2)) # sum[bit] comes from add.fa[bit].ha2.sum.layer2 sum_bit = reg.register(f"{prefix}.add.fa{bit}.ha2.sum.layer2") # C[stage] in MSB-first c_input = reg.get_id(f"$c[{7-stage}]") return [sum_bit, c_input] return [] # Accumulator adders: acc.s[stage].fa[bit] if '.mul.acc.' in gate: m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate) if not m: return [] stage = int(m.group(1)) # 1-7 bit = int(m.group(2)) # 0-7 # A input: previous stage output if stage == 1: # First accumulator: A = mask.s0.b[bit] (AND gate output) a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}") else: # Later stages: A = previous accumulator sum a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2") # B input: (mask.s[stage] << stage)[bit] if bit < stage: b_input = reg.get_id("#0") else: b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}") # Carry input if bit == 0: cin = reg.get_id("#0") else: cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or") fa_prefix = f"{prefix}.mul.acc.s{stage}.fa{bit}" if '.ha1.sum.layer1' in gate: return [a_input, b_input] if '.ha1.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] if '.ha1.carry' in gate and '.layer' not in gate: return [a_input, b_input] if '.ha2.sum.layer1' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.ha2.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] if '.ha2.carry' in gate and '.layer' not in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.carry_or' in gate: return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] return [] return [] def infer_expr_paren_inputs(gate: str, reg: SignalRegistry) -> List[int]: """Infer inputs for (A + B) × C expression circuit (parenthetical grouping). Circuit structure: - Add stage: sum = A + B - Mask stage: mask.s[stage].b[bit] = sum[bit] AND C[stage] - Accumulator stages 1-7: acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage) Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB) """ prefix = "arithmetic.expr_paren" # Register all inputs for i in range(8): reg.register(f"$a[{i}]") reg.register(f"$b[{i}]") reg.register(f"$c[{i}]") # Add stage: sum = A + B if '.add.fa' in gate and '.mul.' not in gate: m = re.search(r'\.fa(\d+)\.', gate) if not m: return [] bit = int(m.group(1)) # Inputs: $a[7-bit], $b[7-bit] a_input = reg.get_id(f"$a[{7-bit}]") b_input = reg.get_id(f"$b[{7-bit}]") # Carry input if bit == 0: cin = reg.get_id("#0") else: cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or") fa_prefix = f"{prefix}.add.fa{bit}" if '.ha1.sum.layer1' in gate: return [a_input, b_input] if '.ha1.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] if '.ha1.carry' in gate and '.layer' not in gate: return [a_input, b_input] if '.ha2.sum.layer1' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.ha2.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] if '.ha2.carry' in gate and '.layer' not in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.carry_or' in gate: return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] return [] # Mask AND gates: mask.s[stage].b[bit] = sum[bit] AND C[stage] if '.mul.mask.' in gate: m = re.search(r'\.s(\d+)\.b(\d+)', gate) if m: stage = int(m.group(1)) bit = int(m.group(2)) # sum[bit] comes from add stage output sum_input = reg.register(f"{prefix}.add.fa{bit}.ha2.sum.layer2") # C[stage] in MSB-first: $c[7-stage] c_input = reg.get_id(f"$c[{7-stage}]") return [sum_input, c_input] return [] # Accumulator adders: acc.s[stage].fa[bit] if '.mul.acc.' in gate: m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate) if not m: return [] stage = int(m.group(1)) # 1-7 bit = int(m.group(2)) # 0-7 # A input: previous stage output if stage == 1: # First accumulator: A = mask.s0.b[bit] (AND gate output) a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}") else: # Later stages: A = previous accumulator sum a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2") # B input: (mask.s[stage] << stage)[bit] if bit < stage: b_input = reg.get_id("#0") else: b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}") # Carry input if bit == 0: cin = reg.get_id("#0") else: cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or") fa_prefix = f"{prefix}.mul.acc.s{stage}.fa{bit}" if '.ha1.sum.layer1' in gate: return [a_input, b_input] if '.ha1.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] if '.ha1.carry' in gate and '.layer' not in gate: return [a_input, b_input] if '.ha2.sum.layer1' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.ha2.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] if '.ha2.carry' in gate and '.layer' not in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.carry_or' in gate: return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] return [] return [] def infer_add3_inputs(gate: str, reg: SignalRegistry) -> List[int]: """Infer inputs for 3-operand adder: A + B + C.""" prefix = "arithmetic.add3_8bit" # Register all inputs for i in range(8): reg.register(f"$a[{i}]") reg.register(f"$b[{i}]") reg.register(f"$c[{i}]") # Parse stage and bit if '.stage1.' in gate: m = re.search(r'\.fa(\d+)\.', gate) if not m: return [] bit = int(m.group(1)) # Stage 1: A + B (LSB is index 7 in MSB-first) a_bit = reg.get_id(f"$a[{7-bit}]") b_bit = reg.get_id(f"$b[{7-bit}]") cin = reg.get_id("#0") if bit == 0 else reg.register(f"{prefix}.stage1.fa{bit-1}.carry_or") fa_prefix = f"{prefix}.stage1.fa{bit}" elif '.stage2.' in gate: m = re.search(r'\.fa(\d+)\.', gate) if not m: return [] bit = int(m.group(1)) # Stage 2: stage1_result + C temp_bit = reg.register(f"{prefix}.stage1.fa{bit}.ha2.sum.layer2") c_bit = reg.get_id(f"$c[{7-bit}]") cin = reg.get_id("#0") if bit == 0 else reg.register(f"{prefix}.stage2.fa{bit-1}.carry_or") a_bit = temp_bit b_bit = c_bit fa_prefix = f"{prefix}.stage2.fa{bit}" else: return [] if '.ha1.sum.layer1' in gate: return [a_bit, b_bit] if '.ha1.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")] if '.ha1.carry' in gate and '.layer' not in gate: return [a_bit, b_bit] if '.ha2.sum.layer1' in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.ha2.sum.layer2' in gate: return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")] if '.ha2.carry' in gate and '.layer' not in gate: return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin] if '.carry_or' in gate: return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")] return [] def infer_adcsbc_inputs(gate: str, prefix: str, is_sub: bool, reg: SignalRegistry) -> List[int]: for i in range(8): reg.register(f"{prefix}.$a[{i}]") reg.register(f"{prefix}.$b[{i}]") reg.register(f"{prefix}.$cin") if is_sub and '.notb' in gate: m = re.search(r'\.notb(\d+)', gate) if m: return [reg.get_id(f"{prefix}.$b[{int(m.group(1))}]")] return [] m = re.search(r'\.fa(\d+)\.', gate) if not m: return [] bit = int(m.group(1)) if is_sub: a_bit = reg.get_id(f"{prefix}.$a[{bit}]") notb = reg.register(f"{prefix}.notb{bit}") else: a_bit = reg.get_id(f"{prefix}.$a[{bit}]") notb = reg.get_id(f"{prefix}.$b[{bit}]") cin = reg.get_id(f"{prefix}.$cin") if bit == 0 else reg.register(f"{prefix}.fa{bit-1}.or_carry") fa_prefix = f"{prefix}.fa{bit}" if '.xor1.layer1' in gate: return [a_bit, notb if is_sub else reg.get_id(f"{prefix}.$b[{bit}]")] if '.xor1.layer2' in gate: return [reg.register(f"{fa_prefix}.xor1.layer1.or"), reg.register(f"{fa_prefix}.xor1.layer1.nand")] if '.xor2.layer1' in gate: return [reg.register(f"{fa_prefix}.xor1.layer2"), cin] if '.xor2.layer2' in gate: return [reg.register(f"{fa_prefix}.xor2.layer1.or"), reg.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, notb if is_sub else reg.get_id(f"{prefix}.$b[{bit}]")] if '.and2' in gate: return [reg.register(f"{fa_prefix}.xor1.layer2"), cin] if '.or_carry' in gate: return [reg.register(f"{fa_prefix}.and1"), reg.register(f"{fa_prefix}.and2")] return [] def infer_sub8bit_inputs(gate: str, reg: SignalRegistry) -> List[int]: prefix = "arithmetic.sub8bit" for i in range(8): reg.register(f"{prefix}.$a[{i}]") reg.register(f"{prefix}.$b[{i}]") if gate == f"{prefix}.carry_in": return [reg.get_id("#1")] if '.notb' in gate: m = re.search(r'\.notb(\d+)', gate) if m: return [reg.get_id(f"{prefix}.$b[{int(m.group(1))}]")] return [] m = re.search(r'\.fa(\d+)\.', gate) if not m: return [] bit = int(m.group(1)) a_bit = reg.get_id(f"{prefix}.$a[{bit}]") notb = reg.register(f"{prefix}.notb{bit}") cin = reg.get_id("#1") if bit == 0 else reg.register(f"{prefix}.fa{bit-1}.or_carry") fa_prefix = f"{prefix}.fa{bit}" if '.xor1.layer1' in gate: return [a_bit, notb] if '.xor1.layer2' in gate: return [reg.register(f"{fa_prefix}.xor1.layer1.or"), reg.register(f"{fa_prefix}.xor1.layer1.nand")] if '.xor2.layer1' in gate: return [reg.register(f"{fa_prefix}.xor1.layer2"), cin] if '.xor2.layer2' in gate: return [reg.register(f"{fa_prefix}.xor2.layer1.or"), reg.register(f"{fa_prefix}.xor2.layer1.nand")] if '.and1' in gate: return [a_bit, notb] if '.and2' in gate: return [reg.register(f"{fa_prefix}.xor1.layer2"), cin] if '.or_carry' in gate: return [reg.register(f"{fa_prefix}.and1"), reg.register(f"{fa_prefix}.and2")] return [] def infer_threshold_inputs(gate: str, reg: SignalRegistry) -> List[int]: for i in range(8): reg.register(f"$x[{i}]") return [reg.get_id(f"$x[{i}]") for i in range(8)] def infer_modular_inputs(gate: str, reg: SignalRegistry) -> List[int]: for i in range(8): reg.register(f"$x[{i}]") if '.layer1' in gate or '.layer2' in gate or '.layer3' in gate: if 'layer1.geq' in gate or 'layer1.leq' in gate: return [reg.get_id(f"$x[{i}]") for i in range(8)] if 'layer2.eq' in gate: m = re.search(r'layer2\.eq(\d+)', gate) if m: idx = m.group(1) parent = gate.rsplit('.layer2', 1)[0] return [reg.register(f"{parent}.layer1.geq{idx}"), reg.register(f"{parent}.layer1.leq{idx}")] if 'layer3.or' in gate: parent = gate.rsplit('.layer3', 1)[0] eq_gates = [] for i in range(256): eq_gate = f"{parent}.layer2.eq{i}" if eq_gate in reg.name_to_id: eq_gates.append(reg.get_id(eq_gate)) return eq_gates if eq_gates else [reg.get_id(f"$x[{i}]") for i in range(8)] return [reg.get_id(f"$x[{i}]") for i in range(8)] def infer_control_jump_inputs(gate: str, prefix: str, reg: SignalRegistry) -> List[int]: for i in range(8): reg.register(f"{prefix}.$pc[{i}]") reg.register(f"{prefix}.$target[{i}]") flag = "$cond" if "jz" in prefix: flag = "$zero" elif "jc" in prefix: flag = "$carry" elif "jn" in prefix and "jnc" not in prefix and "jnz" not in prefix and "jnv" not in prefix: flag = "$negative" elif "jv" in prefix and "jnv" not in prefix: flag = "$overflow" elif "jp" in prefix: flag = "$positive" elif "jnc" in prefix: flag = "$not_carry" elif "jnz" in prefix: flag = "$not_zero" elif "jnv" in prefix: flag = "$not_overflow" reg.register(f"{prefix}.{flag}") m = re.search(r'\.bit(\d+)\.', gate) if not m: return [] bit = int(m.group(1)) bit_prefix = f"{prefix}.bit{bit}" if '.not_sel' in gate: return [reg.get_id(f"{prefix}.{flag}")] if '.and_a' in gate: return [reg.get_id(f"{prefix}.$pc[{bit}]"), reg.register(f"{bit_prefix}.not_sel")] if '.and_b' in gate: return [reg.get_id(f"{prefix}.$target[{bit}]"), reg.get_id(f"{prefix}.{flag}")] if '.or' in gate: return [reg.register(f"{bit_prefix}.and_a"), reg.register(f"{bit_prefix}.and_b")] return [] def infer_buffer_inputs(gate: str, reg: SignalRegistry) -> List[int]: m = re.search(r'\.bit(\d+)$', gate) if m: bit = int(m.group(1)) prefix = gate.rsplit('.bit', 1)[0] return [reg.register(f"{prefix}.$data[{bit}]")] return [reg.register("$data")] def infer_memory_inputs(gate: str, reg: SignalRegistry) -> List[int]: if 'addr_decode' in gate: return [reg.register(f"$addr[{i}]") for i in range(16)] if 'read' in gate: return [reg.register("$mem"), reg.register("$sel")] if 'write' in gate: return [reg.register("$mem"), reg.register("$data"), reg.register("$sel"), reg.register("$we")] return [] def infer_alu_inputs(gate: str, reg: SignalRegistry) -> List[int]: for i in range(8): reg.register(f"$a[{i}]") reg.register(f"$b[{i}]") for i in range(4): reg.register(f"$opcode[{i}]") if 'alucontrol' in gate: return [reg.get_id(f"$opcode[{i}]") for i in range(4)] if 'aluflags' in gate: return [reg.register("$result"), reg.register("$carry"), reg.register("$overflow")] if '.shl.bit' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) if bit < 7: return [reg.get_id(f"$a[{bit + 1}]")] else: return [reg.get_id("#0")] return [reg.get_id(f"$a[{i}]") for i in range(8)] if '.shr.bit' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) if bit > 0: return [reg.get_id(f"$a[{bit - 1}]")] else: return [reg.get_id("#0")] return [reg.get_id(f"$a[{i}]") for i in range(8)] if '.mul.pp.a' in gate: m = re.search(r'a(\d+)b(\d+)', gate) if m: i, j = int(m.group(1)), int(m.group(2)) return [reg.get_id(f"$a[{i}]"), reg.get_id(f"$b[{j}]")] return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] if '.mul.' in gate: return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] if '.div.stage' in gate: if '.cmp' in gate: return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] if '.mux.bit' in gate: m = re.search(r'stage(\d+)\.mux\.bit(\d+)', gate) if m: stage, bit = int(m.group(1)), int(m.group(2)) prefix = f"alu.alu8bit.div.stage{stage}" if '.not_sel' in gate: return [reg.register(f"{prefix}.cmp")] if '.and_a' in gate: return [reg.register(f"$rem[{bit}]"), reg.register(f"{prefix}.mux.bit{bit}.not_sel")] if '.and_b' in gate: return [reg.register(f"$sub[{bit}]"), reg.register(f"{prefix}.cmp")] if '.or' in gate: return [reg.register(f"{prefix}.mux.bit{bit}.and_a"), reg.register(f"{prefix}.mux.bit{bit}.and_b")] return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] if '.inc.bit' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) prefix = f"alu.alu8bit.inc.bit{bit}" if 'layer1' in gate: if bit == 7: return [reg.get_id(f"$a[{bit}]"), reg.get_id("#1")] else: return [reg.get_id(f"$a[{bit}]"), reg.register(f"alu.alu8bit.inc.bit{bit+1}.carry")] if 'layer2' in gate: return [reg.register(f"{prefix}.xor.layer1.or"), reg.register(f"{prefix}.xor.layer1.nand")] if '.carry' in gate: if bit == 7: return [reg.get_id(f"$a[{bit}]"), reg.get_id("#1")] else: return [reg.get_id(f"$a[{bit}]"), reg.register(f"alu.alu8bit.inc.bit{bit+1}.carry")] return [reg.get_id(f"$a[{i}]") for i in range(8)] if '.dec.bit' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) prefix = f"alu.alu8bit.dec.bit{bit}" if '.not_a' in gate: return [reg.get_id(f"$a[{bit}]")] if 'layer1' in gate: if bit == 7: return [reg.get_id(f"$a[{bit}]"), reg.get_id("#1")] else: return [reg.get_id(f"$a[{bit}]"), reg.register(f"alu.alu8bit.dec.bit{bit+1}.borrow")] if 'layer2' in gate: return [reg.register(f"{prefix}.xor.layer1.or"), reg.register(f"{prefix}.xor.layer1.nand")] if '.borrow' in gate: if bit == 7: return [reg.register(f"{prefix}.not_a"), reg.get_id("#1")] else: return [reg.register(f"{prefix}.not_a"), reg.register(f"alu.alu8bit.dec.bit{bit+1}.borrow")] return [reg.get_id(f"$a[{i}]") for i in range(8)] if '.neg.' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) if '.not.bit' in gate: return [reg.get_id(f"$a[{bit}]")] prefix = f"alu.alu8bit.neg.inc.bit{bit}" not_bit = f"alu.alu8bit.neg.not.bit{bit}" if 'layer1' in gate: if bit == 7: return [reg.register(not_bit), reg.get_id("#1")] else: return [reg.register(not_bit), reg.register(f"alu.alu8bit.neg.inc.bit{bit+1}.carry")] if 'layer2' in gate: return [reg.register(f"{prefix}.xor.layer1.or"), reg.register(f"{prefix}.xor.layer1.nand")] if '.carry' in gate: if bit == 7: return [reg.register(not_bit), reg.get_id("#1")] else: return [reg.register(not_bit), reg.register(f"alu.alu8bit.neg.inc.bit{bit+1}.carry")] return [reg.get_id(f"$a[{i}]") for i in range(8)] if '.rol.bit' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) src = (bit + 1) % 8 return [reg.get_id(f"$a[{src}]")] return [reg.get_id(f"$a[{i}]") for i in range(8)] if '.ror.bit' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) src = (bit - 1) % 8 return [reg.get_id(f"$a[{src}]")] return [reg.get_id(f"$a[{i}]") for i in range(8)] if '.and' in gate or '.or' in gate or '.xor' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) return [reg.get_id(f"$a[{bit}]"), reg.get_id(f"$b[{bit}]")] return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] if '.not' in gate: m = re.search(r'bit(\d+)', gate) if m: return [reg.get_id(f"$a[{int(m.group(1))}]")] return [reg.get_id(f"$a[{i}]") for i in range(8)] if 'layer1' in gate or 'layer2' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) if 'layer1' in gate: return [reg.get_id(f"$a[{bit}]"), reg.get_id(f"$b[{bit}]")] parent = gate.rsplit('.layer2', 1)[0] return [reg.register(f"{parent}.layer1.or"), reg.register(f"{parent}.layer1.nand")] return [reg.get_id(f"$a[{i}]") for i in range(8)] def infer_pattern_inputs(gate: str, reg: SignalRegistry) -> List[int]: for i in range(8): reg.register(f"$x[{i}]") if 'hammingdistance' in gate: for i in range(8): reg.register(f"$a[{i}]") reg.register(f"$b[{i}]") return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] return [reg.get_id(f"$x[{i}]") for i in range(8)] def infer_error_detection_inputs(gate: str, reg: SignalRegistry) -> List[int]: for i in range(8): reg.register(f"$x[{i}]") if 'hamming' in gate: if 'encode' in gate: for i in range(4): reg.register(f"$d[{i}]") return [reg.get_id(f"$d[{i}]") for i in range(4)] if 'decode' in gate or 'syndrome' in gate: for i in range(7): reg.register(f"$c[{i}]") return [reg.get_id(f"$c[{i}]") for i in range(7)] if 'crc' in gate: return [reg.register(f"$data[{i}]") for i in range(8)] if 'parity' in gate and 'stage' in gate: m = re.search(r'stage(\d+)\.xor(\d+)', gate) if m: stage = int(m.group(1)) idx = int(m.group(2)) if stage == 1: return [reg.get_id(f"$x[{2*idx}]"), reg.get_id(f"$x[{2*idx+1}]")] parent = gate.rsplit(f'.stage{stage}', 1)[0] prev_stage = stage - 1 return [ reg.register(f"{parent}.stage{prev_stage}.xor{2*idx}.layer2"), reg.register(f"{parent}.stage{prev_stage}.xor{2*idx+1}.layer2") ] if 'output.not' in gate: parent = gate.rsplit('.output', 1)[0] return [reg.register(f"{parent}.stage3.xor0.layer2")] return [reg.get_id(f"$x[{i}]") for i in range(8)] def infer_combinational_inputs(gate: str, reg: SignalRegistry) -> List[int]: if 'decoder3to8' in gate: for i in range(3): reg.register(f"$sel[{i}]") return [reg.get_id(f"$sel[{i}]") for i in range(3)] if 'encoder8to3' in gate: for i in range(8): reg.register(f"$x[{i}]") return [reg.get_id(f"$x[{i}]") for i in range(8)] if 'multiplexer' in gate: if '2to1' in gate: return [reg.register("$a"), reg.register("$b"), reg.register("$sel")] if '4to1' in gate: return [reg.register(f"$x[{i}]") for i in range(4)] + [reg.register(f"$sel[{i}]") for i in range(2)] if '8to1' in gate: return [reg.register(f"$x[{i}]") for i in range(8)] + [reg.register(f"$sel[{i}]") for i in range(3)] if 'demultiplexer' in gate: return [reg.register("$x"), reg.register("$sel")] if 'regmux4to1' in gate: for r in range(4): for i in range(8): reg.register(f"$r{r}[{i}]") for i in range(2): reg.register(f"$sel[{i}]") if gate == "combinational.regmux4to1.not_s0": return [reg.get_id("$sel[0]")] if gate == "combinational.regmux4to1.not_s1": return [reg.get_id("$sel[1]")] m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) if '.not_s' in gate: sidx = 0 if 's0' in gate else 1 return [reg.get_id(f"$sel[{sidx}]")] if '.and' in gate: and_m = re.search(r'\.and(\d+)', gate) if and_m: and_idx = int(and_m.group(1)) sel0 = "combinational.regmux4to1.not_s0" if (and_idx & 1) == 0 else "$sel[0]" sel1 = "combinational.regmux4to1.not_s1" if (and_idx & 2) == 0 else "$sel[1]" return [reg.get_id(f"$r{and_idx}[{bit}]"), reg.register(sel0), reg.register(sel1)] if '.or' in gate: return [reg.register(f"combinational.regmux4to1.bit{bit}.and{i}") for i in range(4)] return [] if 'barrelshifter' in gate: for i in range(8): reg.register(f"$x[{i}]") for i in range(3): reg.register(f"$shift[{i}]") m = re.search(r'layer(\d+)\.bit(\d+)', gate) if m: layer, bit = int(m.group(1)), int(m.group(2)) shift_amount = 1 << (2 - layer) prefix = f"combinational.barrelshifter.layer{layer}.bit{bit}" if '.not_sel' in gate: return [reg.get_id(f"$shift[{2 - layer}]")] if '.and_a' in gate: if layer == 0: return [reg.get_id(f"$x[{bit}]"), reg.register(f"{prefix}.not_sel")] else: prev_prefix = f"combinational.barrelshifter.layer{layer-1}.bit{bit}" return [reg.register(f"{prev_prefix}.or"), reg.register(f"{prefix}.not_sel")] if '.and_b' in gate: src = (bit + shift_amount) % 8 if layer == 0: return [reg.get_id(f"$x[{src}]"), reg.get_id(f"$shift[{2 - layer}]")] else: prev_prefix = f"combinational.barrelshifter.layer{layer-1}.bit{src}" return [reg.register(f"{prev_prefix}.or"), reg.get_id(f"$shift[{2 - layer}]")] if '.or' in gate: return [reg.register(f"{prefix}.and_a"), reg.register(f"{prefix}.and_b")] return [reg.get_id(f"$x[{i}]") for i in range(8)] if 'priorityencoder' in gate: for i in range(8): reg.register(f"$x[{i}]") if '.any_ge' in gate: m = re.search(r'any_ge(\d+)', gate) if m: pos = int(m.group(1)) return [reg.get_id(f"$x[{i}]") for i in range(pos, 8)] if '.is_highest' in gate: m = re.search(r'is_highest(\d+)', gate) if m: pos = int(m.group(1)) if '.not_higher' in gate: if pos == 0: return [reg.get_id("#0")] else: return [reg.register(f"combinational.priorityencoder.any_ge{pos-1}")] if '.and' in gate: return [reg.get_id(f"$x[{pos}]"), reg.register(f"combinational.priorityencoder.is_highest{pos}.not_higher")] if '.out' in gate: m = re.search(r'out(\d+)', gate) if m: out_bit = int(m.group(1)) inputs = [] for pos in range(8): if (pos >> out_bit) & 1: inputs.append(reg.register(f"combinational.priorityencoder.is_highest{pos}.and")) return inputs if '.valid' in gate: return [reg.get_id(f"$x[{i}]") for i in range(8)] return [reg.get_id(f"$x[{i}]") for i in range(8)] return [] def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor]) -> List[int]: if gate.startswith('manifest.'): return [] if gate.startswith('boolean.'): return infer_boolean_inputs(gate, reg) if gate.startswith('arithmetic.'): if 'halfadder' in gate: return infer_halfadder_inputs(gate, "arithmetic.halfadder", reg) if 'fulladder' in gate: return infer_fulladder_inputs(gate, "arithmetic.fulladder", reg) if 'ripplecarry2bit' in gate: return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry2bit", 2, reg) if 'ripplecarry4bit' in gate: return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry4bit", 4, reg) if 'ripplecarry8bit' in gate: return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry8bit", 8, reg) if 'add3_8bit' in gate: return infer_add3_inputs(gate, reg) if 'expr_add_mul' in gate and 'paren' not in gate: return infer_expr_add_mul_inputs(gate, reg) if 'expr_paren_add_mul' in gate: return infer_expr_paren_add_mul_inputs(gate, reg) if 'adc8bit' in gate: return infer_adcsbc_inputs(gate, "arithmetic.adc8bit", False, reg) if 'sbc8bit' in gate: return infer_adcsbc_inputs(gate, "arithmetic.sbc8bit", True, reg) if 'sub8bit' in gate: return infer_sub8bit_inputs(gate, reg) if any(cmp in gate for cmp in ['greaterthan8bit', 'lessthan8bit', 'greaterorequal8bit', 'lessorequal8bit']): for i in range(8): reg.register(f"$a[{i}]") reg.register(f"$b[{i}]") return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] if 'equality8bit' in gate: for i in range(8): reg.register(f"$a[{i}]") reg.register(f"$b[{i}]") if 'layer1' in gate: return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] if 'layer2' in gate: return [reg.register("arithmetic.equality8bit.layer1.geq"), reg.register("arithmetic.equality8bit.layer1.leq")] return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)] for i in range(8): reg.register(f"$a[{i}]") reg.register(f"$b[{i}]") return [reg.get_id(f"$a[{i}]") for i in range(8)] if gate.startswith('threshold.'): return infer_threshold_inputs(gate, reg) if gate.startswith('modular.'): return infer_modular_inputs(gate, reg) if gate.startswith('control.'): if any(j in gate for j in ['jz', 'jc', 'jn', 'jv', 'jp', 'jnz', 'jnc', 'jnv', 'conditionaljump']): prefix = gate.split('.bit')[0] if '.bit' in gate else gate.rsplit('.', 1)[0] return infer_control_jump_inputs(gate, prefix, reg) if any(b in gate for b in ['fetch', 'load', 'store', 'mem_addr']): return infer_buffer_inputs(gate, reg) if 'push.sp_dec' in gate or 'pop.sp_inc' in gate: for i in range(16): reg.register(f"$sp[{i}]") m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) op = 'push.sp_dec' if 'push' in gate else 'pop.sp_inc' prefix = f"control.{op}.bit{bit}" if 'layer1' in gate: if bit == 15: return [reg.get_id(f"$sp[{bit}]"), reg.get_id("#1")] else: carry_name = 'borrow' if 'push' in gate else 'carry' return [reg.get_id(f"$sp[{bit}]"), reg.register(f"control.{op}.bit{bit+1}.{carry_name}")] if 'layer2' in gate: return [reg.register(f"{prefix}.xor.layer1.or"), reg.register(f"{prefix}.xor.layer1.nand")] if '.borrow' in gate or '.carry' in gate: if bit == 15: return [reg.get_id(f"$sp[{bit}]"), reg.get_id("#1")] else: carry_name = 'borrow' if 'push' in gate else 'carry' return [reg.get_id(f"$sp[{bit}]"), reg.register(f"control.{op}.bit{bit+1}.{carry_name}")] return [reg.get_id(f"$sp[{i}]") for i in range(16)] if 'ret.addr' in gate: m = re.search(r'bit(\d+)', gate) if m: bit = int(m.group(1)) return [reg.register(f"$ret_addr[{bit}]")] return [reg.register(f"$ret_addr[{i}]") for i in range(16)] return [reg.register("$ctrl")] if gate.startswith('memory.'): return infer_memory_inputs(gate, reg) if gate.startswith('alu.'): return infer_alu_inputs(gate, reg) if gate.startswith('pattern_recognition.'): return infer_pattern_inputs(gate, reg) if gate.startswith('error_detection.'): return infer_error_detection_inputs(gate, reg) if gate.startswith('combinational.'): return infer_combinational_inputs(gate, reg) weight_key = f"{gate}.weight" if weight_key in tensors: w = tensors[weight_key] n_inputs = w.shape[0] if w.dim() == 1 else w.shape[-1] for i in range(n_inputs): reg.register(f"$input[{i}]") return [reg.get_id(f"$input[{i}]") for i in range(n_inputs)] return [] def build_inputs(tensors: Dict[str, torch.Tensor]) -> tuple[Dict[str, torch.Tensor], SignalRegistry, dict]: reg = SignalRegistry() gates = get_all_gates(tensors) stats = {"added": 0, "skipped": 0, "empty": 0} for gate in sorted(gates): inputs_key = f"{gate}.inputs" if inputs_key in tensors: stats["skipped"] += 1 continue inputs = infer_inputs_for_gate(gate, reg, tensors) if inputs: tensors[inputs_key] = torch.tensor(inputs, dtype=torch.int64) stats["added"] += 1 else: stats["empty"] += 1 return tensors, reg, stats def resolve_memory_config(args) -> tuple: """Resolve memory configuration from args, returns (addr_bits, mem_bytes).""" if hasattr(args, 'memory_profile') and args.memory_profile: addr_bits = MEMORY_PROFILES[args.memory_profile] elif hasattr(args, 'addr_bits') and args.addr_bits is not None: addr_bits = args.addr_bits else: addr_bits = DEFAULT_ADDR_BITS mem_bytes = (1 << addr_bits) if addr_bits > 0 else 0 return addr_bits, mem_bytes def cmd_memory(args) -> None: addr_bits, mem_bytes = resolve_memory_config(args) print("=" * 60) print(" BUILD MEMORY CIRCUITS") print("=" * 60) print(f"\nMemory configuration:") print(f" Address bits: {addr_bits}") print(f" Memory bytes: {mem_bytes:,}") if addr_bits == 0: print(f" Mode: PURE ALU (no memory)") elif addr_bits <= 4: print(f" Mode: LLM registers") elif addr_bits <= 8: print(f" Mode: LLM scratchpad") elif addr_bits <= 12: print(f" Mode: Reduced CPU") else: print(f" Mode: Full CPU") print(f"\nLoading: {args.model}") tensors = load_tensors(args.model) print(f" Loaded {len(tensors)} tensors") print("\nDropping existing memory/control tensors...") drop_prefixes(tensors, [ "memory.addr_decode.", "memory.read.", "memory.write.", "control.fetch.ir.", "control.load.", "control.store.", "control.mem_addr.", ]) print(f" Now {len(tensors)} tensors") if addr_bits > 0: print("\nGenerating memory circuits...") add_decoder(tensors, addr_bits, mem_bytes) add_memory_read_mux(tensors, mem_bytes) add_memory_write_cells(tensors, mem_bytes) print(" Added decoder, read mux, write cells") print("\nGenerating buffer gates...") try: add_fetch_load_store_buffers(tensors, addr_bits) print(" Added fetch/load/store/mem_addr buffers") except ValueError as e: print(f" Buffers already exist: {e}") else: print("\nSkipping memory circuits (addr_bits=0, pure ALU mode)") print("\nUpdating manifest...") update_manifest(tensors, addr_bits, mem_bytes) print(f" memory_bytes={mem_bytes:,}, pc_width={addr_bits}") if args.apply: print(f"\nSaving: {args.model}") save_file(tensors, str(args.model)) if args.manifest: write_manifest(MANIFEST_PATH, tensors) print(f" Wrote manifest: {MANIFEST_PATH}") print(" Done.") else: print("\n[DRY-RUN] Use --apply to save.") print(f"\nTotal: {len(tensors)} tensors") mem_params = sum(t.numel() for k, t in tensors.items() if k.startswith("memory.")) alu_params = sum(t.numel() for k, t in tensors.items() if not k.startswith("memory.") and not k.startswith("manifest.")) print(f" Memory params: {mem_params:,}") print(f" ALU/Logic params: {alu_params:,}") print("=" * 60) def cmd_inputs(args) -> None: print("=" * 60) print(" BUILD .inputs TENSORS") print("=" * 60) print(f"\nLoading: {args.model}") tensors = load_tensors(args.model) print(f" Loaded {len(tensors)} tensors") gates = get_all_gates(tensors) print(f" Found {len(gates)} gates") print("\nBuilding .inputs tensors...") tensors, reg, stats = build_inputs(tensors) print(f"\nResults:") print(f" Added: {stats['added']}") print(f" Skipped: {stats['skipped']}") print(f" Empty: {stats['empty']}") print(f" Signals: {len(reg.name_to_id)}") print(f" Total: {len(tensors)}") if args.apply: print(f"\nSaving: {args.model}") metadata = {"signal_registry": reg.to_metadata()} save_file(tensors, str(args.model), metadata=metadata) print(" Done.") else: print("\n[DRY-RUN] Use --apply to save.") print("=" * 60) def cmd_alu(args) -> None: print("=" * 60) print(" BUILD ALU CIRCUITS") print("=" * 60) print(f"\nLoading: {args.model}") tensors = load_tensors(args.model) print(f" Loaded {len(tensors)} tensors") print("\nDropping existing ALU extension tensors...") drop_prefixes(tensors, [ "alu.alu8bit.shl.", "alu.alu8bit.shr.", "alu.alu8bit.mul.", "alu.alu8bit.div.", "alu.alu8bit.inc.", "alu.alu8bit.dec.", "alu.alu8bit.neg.", "alu.alu8bit.rol.", "alu.alu8bit.ror.", "arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.", "arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.", "arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.", "control.push.", "control.pop.", "control.ret.", "combinational.barrelshifter.", "combinational.priorityencoder.", ]) print(f" Now {len(tensors)} tensors") print("\nGenerating SHL/SHR circuits...") try: add_shl_shr(tensors) print(" Added SHL (8 gates), SHR (8 gates)") except ValueError as e: print(f" SHL/SHR already exist: {e}") print("\nGenerating MUL circuit...") try: add_mul(tensors) print(" Added MUL (64 partial product AND gates)") except ValueError as e: print(f" MUL already exists: {e}") print("\nGenerating DIV circuit...") try: add_div(tensors) print(" Added DIV (8 stages x comparison + mux)") except ValueError as e: print(f" DIV already exists: {e}") print("\nGenerating INC/DEC circuits...") try: add_inc_dec(tensors) print(" Added INC (32 gates), DEC (40 gates)") except ValueError as e: print(f" INC/DEC already exist: {e}") print("\nGenerating NEG circuit...") try: add_neg(tensors) print(" Added NEG (40 gates)") except ValueError as e: print(f" NEG already exists: {e}") print("\nGenerating ROL/ROR circuits...") try: add_rol_ror(tensors) print(" Added ROL (8 gates), ROR (8 gates)") except ValueError as e: print(f" ROL/ROR already exist: {e}") print("\nGenerating stack operation circuits...") try: add_stack_ops(tensors) print(" Added PUSH/POP/RET (144 gates)") except ValueError as e: print(f" Stack ops already exist: {e}") print("\nGenerating barrel shifter...") try: add_barrel_shifter(tensors) print(" Added barrel shifter (96 gates)") except ValueError as e: print(f" Barrel shifter already exists: {e}") print("\nGenerating priority encoder...") try: add_priority_encoder(tensors) print(" Added priority encoder (28 gates)") except ValueError as e: print(f" Priority encoder already exists: {e}") print("\nGenerating comparator circuits...") try: add_comparators(tensors) print(" Added GT, GE, LT, LE (single-layer), EQ (two-layer)") except ValueError as e: print(f" Comparators already exist: {e}") print("\nGenerating 3-operand adder circuit...") try: add_add3(tensors) print(" Added ADD3 (16 full adders = 144 gates)") except ValueError as e: print(f" ADD3 already exists: {e}") print("\nGenerating expression A + B × C circuit...") try: add_expr_add_mul(tensors) print(" Added EXPR_ADD_MUL (64 AND + 56 + 8 full adders = 640 gates)") except ValueError as e: print(f" EXPR_ADD_MUL already exists: {e}") print("\nGenerating expression (A + B) × C circuit...") try: add_expr_paren(tensors) print(" Added EXPR_PAREN (8 + 64 AND + 56 full adders = 640 gates)") except ValueError as e: print(f" EXPR_PAREN already exists: {e}") if args.apply: print(f"\nSaving: {args.model}") save_file(tensors, str(args.model)) print(" Done.") else: print("\n[DRY-RUN] Use --apply to save.") print(f"\nTotal: {len(tensors)} tensors") print("=" * 60) def cmd_all(args) -> None: print("Running: memory") cmd_memory(args) print("\nRunning: alu") cmd_alu(args) print("\nRunning: inputs") cmd_inputs(args) def main() -> None: parser = argparse.ArgumentParser( description="Build tools for threshold computer safetensors", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Memory Profiles: full 64KB (16-bit addr) - Full CPU mode reduced 4KB (12-bit addr) - Reduced CPU scratchpad 256B (8-bit addr) - LLM scratchpad registers 16B (4-bit addr) - LLM register file none 0B (no memory) - Pure ALU for LLM Examples: python build.py memory --memory-profile none --apply # LLM-only (no RAM) python build.py memory --memory-profile scratchpad # 256-byte scratchpad python build.py memory --addr-bits 6 # Custom: 64 bytes python build.py memory # Default: 64KB """ ) parser.add_argument("--model", type=Path, default=MODEL_PATH, help="Model path") parser.add_argument("--apply", action="store_true", help="Apply changes (default: dry-run)") parser.add_argument("--manifest", action="store_true", help="Write tensors.txt manifest (memory only)") mem_group = parser.add_mutually_exclusive_group() mem_group.add_argument( "--memory-profile", "-m", choices=list(MEMORY_PROFILES.keys()), help="Memory size profile (full/reduced/scratchpad/registers/none)" ) mem_group.add_argument( "--addr-bits", "-a", type=int, choices=range(0, 17), metavar="N", help="Address bus width in bits (0-16). 0=no memory, 16=64KB" ) subparsers = parser.add_subparsers(dest="command", help="Subcommands") subparsers.add_parser("memory", help="Generate memory circuits (size controlled by --memory-profile or --addr-bits)") subparsers.add_parser("alu", help="Generate ALU extension circuits (SHL, SHR, comparators)") subparsers.add_parser("inputs", help="Add .inputs metadata tensors") subparsers.add_parser("all", help="Run memory, alu, then inputs") args = parser.parse_args() if args.command == "memory": cmd_memory(args) elif args.command == "alu": cmd_alu(args) elif args.command == "inputs": cmd_inputs(args) elif args.command == "all": cmd_all(args) else: parser.print_help() if __name__ == "__main__": main()