""" © SupraLabs 2026 - Inference script for Chimera-50M Reasoning """ import torch from tokenizers import ByteLevelBPETokenizer from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM MODEL_ID = "./Chimera-50M-Reasoning-FINAL" MAX_NEW_TOKENS = 1500 print("[*] Loading tokenizer...") fast_tokenizer = ByteLevelBPETokenizer( "custom_llama_tokenizer-vocab.json", "custom_llama_tokenizer-merges.txt" ) tokenizer = PreTrainedTokenizerFast( tokenizer_object=fast_tokenizer, bos_token="", eos_token="", unk_token="", pad_token="", ) print(f"[*] Loading model from {MODEL_ID}...") model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, device_map="auto", ) model.eval() print(f"[+] Model loaded — {model.num_parameters():,} parameters") SYSTEM_PROMPT = ( "Your role as an assistant involves thoroughly exploring questions through " "a systematic long thinking process before providing the final precise and " "accurate solutions." ) def build_prompt(question: str) -> str: return ( f"[SYSTEM]: {SYSTEM_PROMPT}\n\n" f"[USER]: {question}\n\n" f"[ASSISTANT]: <|begin_of_thought|>\n" ) def generate(question: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str: prompt = build_prompt(question) input_ids = tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt") input_ids = input_ids.to(model.device) prompt_len = input_ids.shape[1] with torch.no_grad(): output_ids = model.generate( input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.3, top_k=25, top_p=0.8, repetition_penalty=1.3, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, ) response_ids = output_ids[0][prompt_len:] raw = tokenizer.decode(response_ids, skip_special_tokens=False).strip() raw = raw.replace("", "").replace("", "").strip() return "<|begin_of_thought|>\n" + raw if __name__ == "__main__": print("\n[+] Ready. Type 'quit' to exit.\n") while True: question = input("Question: ").strip() if question.lower() == "quit": break print("=" * 50) print(generate(question)) print()