""" Run text generation inference on the exported Qwen3.5-0.8B TFLite model. Usage: python inference_tflite.py --model_path output/qwen35_0.8b/qwen35_q8_ekv2048.tflite python inference_tflite.py --prompt "Explain gravity" --max_new_tokens 100 """ import argparse import glob import logging import time import numpy as np import transformers from ai_edge_litert import interpreter as tfl_interpreter logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", ) logger = logging.getLogger(__name__) # Architecture constants (must match qwen35_model.py) NUM_LAYERS = 24 LAYER_TYPES = [ "linear", "linear", "linear", "full", "linear", "linear", "linear", "full", "linear", "linear", "linear", "full", "linear", "linear", "linear", "full", "linear", "linear", "linear", "full", "linear", "linear", "linear", "full", ] LINEAR_QKV_DIM = 6144 LINEAR_CONV_KERNEL = 4 LINEAR_NUM_HEADS = 16 LINEAR_K_HEAD_DIM = 128 LINEAR_V_HEAD_DIM = 128 FULL_ATTN_NUM_KV_HEADS = 2 FULL_ATTN_HEAD_DIM = 256 MODEL_ID = "Qwen/Qwen3.5-0.8B" def create_initial_kv_cache(kv_cache_max_len, batch_size=1): """Create zero-initialized KV cache arrays matching the model's per-layer shapes.""" kv = {} for i in range(NUM_LAYERS): if LAYER_TYPES[i] == "linear": kv[f"kv_cache_k_{i}"] = np.zeros( (batch_size, LINEAR_QKV_DIM, LINEAR_CONV_KERNEL - 1), dtype=np.float32, ) kv[f"kv_cache_v_{i}"] = np.zeros( (batch_size, LINEAR_NUM_HEADS, LINEAR_K_HEAD_DIM, LINEAR_V_HEAD_DIM), dtype=np.float32, ) else: kv[f"kv_cache_k_{i}"] = np.zeros( (batch_size, kv_cache_max_len, FULL_ATTN_NUM_KV_HEADS, FULL_ATTN_HEAD_DIM), dtype=np.float32, ) kv[f"kv_cache_v_{i}"] = np.zeros( (batch_size, kv_cache_max_len, FULL_ATTN_NUM_KV_HEADS, FULL_ATTN_HEAD_DIM), dtype=np.float32, ) return kv def find_prefill_signature(signatures, seq_len): """Find the best prefill signature for the given sequence length.""" prefill_sigs = sorted( [s for s in signatures if s.startswith("prefill_")], key=lambda s: int(s.split("_")[1]), ) if not prefill_sigs: raise ValueError("No prefill signatures found in model") for sig in prefill_sigs: sig_len = int(sig.split("_")[1]) if sig_len >= seq_len: return sig, sig_len # Use largest available largest = prefill_sigs[-1] return largest, int(largest.split("_")[1]) def generate(model_path, prompt, max_new_tokens, kv_cache_max_len): """Run text generation with the TFLite model.""" # Load tokenizer logger.info("Loading tokenizer from: %s", MODEL_ID) tokenizer = transformers.AutoTokenizer.from_pretrained( MODEL_ID, trust_remote_code=True ) # Tokenize prompt input_ids = tokenizer.encode(prompt) logger.info("Prompt: %s", prompt) logger.info("Token count: %d", len(input_ids)) # Load TFLite model logger.info("Loading TFLite model from: %s", model_path) t0 = time.time() interp = tfl_interpreter.Interpreter(model_path=model_path) interp.allocate_tensors() logger.info("Model loaded in %.1fs", time.time() - t0) signatures = interp.get_signature_list() logger.info("Available signatures: %s", list(signatures.keys())) # Initialize KV cache kv_cache = create_initial_kv_cache(kv_cache_max_len) # --- Prefill phase --- sig_name, sig_len = find_prefill_signature(signatures, len(input_ids)) logger.info("Using prefill signature: %s (padding %d -> %d)", sig_name, len(input_ids), sig_len) # Pad input to match signature length padded_ids = input_ids + [0] * (sig_len - len(input_ids)) tokens = np.array([padded_ids], dtype=np.int32) input_pos = np.arange(sig_len, dtype=np.int32) prefill_runner = interp.get_signature_runner(sig_name) t0 = time.time() prefill_out = prefill_runner(tokens=tokens, input_pos=input_pos, **kv_cache) prefill_time = time.time() - t0 logger.info("Prefill done in %.2fs", prefill_time) # Update KV cache from prefill output for key in kv_cache: if key in prefill_out: kv_cache[key] = prefill_out[key] # --- Decode phase --- # Prefill processed sig_len tokens (including padding). Next decode # position is sig_len. We feed the last real token to get the first # generated token. decode_runner = interp.get_signature_runner("decode") generated_ids = list(input_ids) current_pos = sig_len # continue after prefill logger.info("Starting decode (max %d tokens)...", max_new_tokens) print(f"\n--- Generated text ---\n{prompt}", end="", flush=True) t0 = time.time() for step in range(max_new_tokens): # Feed last token, get next tok = np.array([[generated_ids[-1]]], dtype=np.int32) pos = np.array([current_pos], dtype=np.int32) decode_out = decode_runner(tokens=tok, input_pos=pos, **kv_cache) # Update KV cache for key in kv_cache: if key in decode_out: kv_cache[key] = decode_out[key] next_token = int(np.argmax(decode_out["logits"][0, -1])) generated_ids.append(next_token) current_pos += 1 # Print token word = tokenizer.decode([next_token]) print(word, end="", flush=True) # Stop on EOS if next_token == tokenizer.eos_token_id: break decode_time = time.time() - t0 num_decoded = len(generated_ids) - len(input_ids) print(f"\n\n--- Stats ---") print(f"Prefill: {prefill_time:.2f}s ({len(input_ids)} tokens)") print(f"Decode: {decode_time:.2f}s ({num_decoded} tokens, {num_decoded/decode_time:.1f} tok/s)") def main(): parser = argparse.ArgumentParser(description="TFLite inference for Qwen3.5-0.8B") parser.add_argument( "--model_path", default=None, help="Path to .tflite model file", ) parser.add_argument( "--prompt", default="What is the meaning of life?", help="Input prompt", ) parser.add_argument( "--max_new_tokens", type=int, default=50, help="Maximum tokens to generate", ) parser.add_argument( "--kv_cache_max_len", type=int, default=2048, help="KV cache max length (must match exported model)", ) args = parser.parse_args() # Auto-find model if not specified if args.model_path is None: files = glob.glob("output/**/*.tflite", recursive=True) if files: args.model_path = max(files, key=lambda f: __import__("os").path.getmtime(f)) logger.info("Auto-found model: %s", args.model_path) else: raise FileNotFoundError("No .tflite files found in output/") generate(args.model_path, args.prompt, args.max_new_tokens, args.kv_cache_max_len) if __name__ == "__main__": main()