"""
© SupraLabs 2026 - Inference script for Chimera-50M Reasoning
"""
import torch
from tokenizers import ByteLevelBPETokenizer
from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM
# Checkpoint location handed to AutoModelForCausalLM.from_pretrained() below
# (a path relative to the current working directory).
MODEL_ID = "./Chimera-50M-Reasoning-FINAL"
# Default value of generate()'s `max_new_tokens` argument, i.e. the cap on
# how many tokens are produced per answer.
MAX_NEW_TOKENS = 1500
print("[*] Loading tokenizer...")

# Raw byte-level BPE tokenizer, built from the vocab/merges files that are
# expected to sit in the current working directory.
fast_tokenizer = ByteLevelBPETokenizer(
    "custom_llama_tokenizer-vocab.json", "custom_llama_tokenizer-merges.txt"
)

# NOTE(review): every special token below is the empty string. This looks like
# angle-bracket markers that were lost somewhere along the way -- confirm the
# real BOS/EOS/UNK/PAD strings against the tokenizer used for training.
_special_tokens = {
    "bos_token": "",
    "eos_token": "",
    "unk_token": "",
    "pad_token": "",
}

# Wrap the raw tokenizer in the transformers fast-tokenizer interface; this is
# the object generate() uses for encoding and decoding.
tokenizer = PreTrainedTokenizerFast(tokenizer_object=fast_tokenizer, **_special_tokens)
print(f"[*] Loading model from {MODEL_ID}...")

# Pick the weight dtype: bfloat16 when CUDA is available, float32 otherwise.
if torch.cuda.is_available():
    _dtype = torch.bfloat16
else:
    _dtype = torch.float32

# device_map="auto" leaves weight placement to the library.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=_dtype, device_map="auto")

# Switch to evaluation mode; this script only runs inference.
model.eval()
print(f"[+] Model loaded — {model.num_parameters():,} parameters")
# System instruction placed at the top of every prompt by build_prompt().
# The fragments are joined with single spaces into one sentence.
SYSTEM_PROMPT = " ".join(
    [
        "Your role as an assistant involves thoroughly exploring questions through",
        "a systematic long thinking process before providing the final precise and",
        "accurate solutions.",
    ]
)
def build_prompt(question: str) -> str:
    """Assemble the full text prompt for one user question.

    The prompt has three parts: a ``[SYSTEM]`` turn carrying ``SYSTEM_PROMPT``,
    a ``[USER]`` turn carrying *question*, and an opened ``[ASSISTANT]`` turn
    that already contains the ``<|begin_of_thought|>`` marker followed by a
    newline, so the model continues from inside the thought section.

    Args:
        question: The user's question, inserted verbatim.

    Returns:
        The concatenated prompt string.
    """
    system_turn = f"[SYSTEM]: {SYSTEM_PROMPT}\n\n"
    user_turn = f"[USER]: {question}\n\n"
    assistant_turn = "[ASSISTANT]: <|begin_of_thought|>\n"
    return "".join((system_turn, user_turn, assistant_turn))
def generate(question: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Generate the model's answer to *question*.

    The question is wrapped with ``build_prompt``, encoded with the module's
    ``tokenizer``, and sampled from the module's ``model``. Only the newly
    generated tokens (everything after the prompt) are decoded.

    Args:
        question: The user's question.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded continuation, prefixed with ``"<|begin_of_thought|>\\n"``
        (the marker the prompt already opened), with surrounding whitespace
        and any configured BOS/EOS marker text removed.
    """
    prompt = build_prompt(question)
    input_ids = tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
    input_ids = input_ids.to(model.device)

    # A single, unpadded sequence: every position is a real token. Passing the
    # mask explicitly keeps generate() from having to infer it from
    # pad_token_id, which is ambiguous when the pad and EOS ids coincide.
    attention_mask = torch.ones_like(input_ids)

    # Remember where the prompt ends so it can be cut off the output below.
    prompt_len = input_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.3,
            top_k=25,
            top_p=0.8,
            repetition_penalty=1.3,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # output_ids[0] holds prompt + continuation; keep only the continuation.
    response_ids = output_ids[0][prompt_len:]

    # Special tokens are kept during decoding (skip_special_tokens=False) and
    # then removed selectively, so other markers in the output survive.
    raw = tokenizer.decode(response_ids, skip_special_tokens=False).strip()

    # The previous code chained two `replace("", "")` calls, which never change
    # the string. Presumably the BOS/EOS markers were meant, so strip whatever
    # the tokenizer has configured for them; empty/unset markers are skipped,
    # which keeps the old behaviour when none are configured.
    for marker in (tokenizer.bos_token, tokenizer.eos_token):
        if marker:
            raw = raw.replace(marker, "")
    raw = raw.strip()

    return "<|begin_of_thought|>\n" + raw
if __name__ == "__main__":
    # Interactive loop: read a question, print the generated answer, repeat
    # until the user types 'quit' (case-insensitive) or stdin ends.
    print("\n[+] Ready. Type 'quit' to exit.\n")
    while True:
        try:
            question = input("Question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed/piped stdin or Ctrl-C at the prompt: exit cleanly
            # instead of dying with a traceback.
            print()
            break
        if question.lower() == "quit":
            break
        if not question:
            # Nothing was asked; don't spend a full generation on it.
            continue
        print("=" * 50)
        print(generate(question))
        print()