#!/usr/bin/env python3
"""Interactive chat for Education-SLM."""

import sys

import torch
from tokenizers import Tokenizer

from config import cfg
from model import IndustrySLM

# Token id dropped from the end of an encoded prompt before generation.
# NOTE(review): presumably the tokenizer's end-of-sequence id -- confirm
# against the tokenizer vocabulary.
_TRAILING_TOKEN_ID = 3

DEMO_PROMPTS = [
    "Bloom's taxonomy classifies learning objectives into",
    "The flipped classroom model works by",
    "Formative assessment helps teachers to",
    "E-learning platforms have transformed education by",
    "Differentiated instruction addresses",
]


def _warn_incompatible_keys(load_result):
    """Report state-dict keys that a non-strict load skipped, on stderr."""
    # getattr keeps this a best-effort report even if the model class
    # overrides load_state_dict and returns something else.
    missing = getattr(load_result, "missing_keys", None)
    unexpected = getattr(load_result, "unexpected_keys", None)
    if missing:
        print(f"Warning: checkpoint is missing keys: {list(missing)}", file=sys.stderr)
    if unexpected:
        print(f"Warning: checkpoint has unexpected keys: {list(unexpected)}", file=sys.stderr)


def load_model(checkpoint_name="best_model.pt"):
    """Load a checkpoint and the tokenizer, ready for inference.

    Every key in the checkpoint's optional ``"config"`` dict that matches an
    existing attribute of the global ``cfg`` overwrites that attribute before
    the model is constructed, so ``cfg`` is mutated as a side effect.

    Args:
        checkpoint_name: File name inside ``cfg.checkpoint_dir``.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model has been moved to
        ``device`` and switched to eval mode.
    """
    device = torch.device(cfg.device)

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(
        cfg.checkpoint_dir / checkpoint_name,
        map_location=device,
        weights_only=False,
    )

    for key, value in checkpoint.get("config", {}).items():
        if hasattr(cfg, key):
            setattr(cfg, key, value)

    model = IndustrySLM()
    # strict=False tolerates key mismatches; surface them instead of silently
    # running with partially initialised weights.
    load_result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    _warn_incompatible_keys(load_result)
    model.to(device).eval()

    tokenizer = Tokenizer.from_file(str(cfg.tokenizer_dir / cfg.tokenizer_filename))
    print(f"Model loaded: {model.count_parameters()/1e6:.2f}M params")
    return model, tokenizer, device


def generate_response(model, tokenizer, device, prompt, max_tokens=None,
                      temperature=0.8, top_k=50, top_p=0.9):
    """Generate a continuation of ``prompt`` and return it as stripped text.

    Args:
        model: Object exposing ``generate(input_ids, max_new_tokens=...,
            temperature=..., top_k=..., top_p=...)``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        device: Device on which the input tensor is created.
        prompt: Text to continue.
        max_tokens: Maximum number of new tokens; a falsy value falls back to
            ``cfg.max_new_tokens``.
        temperature: Passed through to ``model.generate``.
        top_k: Passed through to ``model.generate``.
        top_p: Passed through to ``model.generate``.

    Returns:
        The decoded text of the tokens that follow the prompt, with
        surrounding whitespace removed.
    """
    prompt_ids = tokenizer.encode(prompt).ids
    if prompt_ids and prompt_ids[-1] == _TRAILING_TOKEN_ID:
        prompt_ids = prompt_ids[:-1]

    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_tokens or cfg.max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    # Assumes generate() returns the prompt followed by the new tokens, so the
    # first input_ids.shape[1] positions are skipped -- TODO confirm.
    new_token_ids = output[0][input_ids.shape[1]:].tolist()

    # The previous code chained .replace("", "") calls here; replacing the
    # empty string with itself is a no-op, so they were dropped. Special
    # tokens are skipped by decode() itself (skip_special_tokens=True is its
    # default and is spelled out here for clarity).
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()


def demo_generation(model, tokenizer, device):
    """Print a response for each entry in ``DEMO_PROMPTS`` (256 tokens max)."""
    print(f"\n{'='*60}\nDemo: {cfg.domain_name}-SLM Inference\n{'='*60}\n")
    for index, prompt in enumerate(DEMO_PROMPTS, 1):
        print(f"[{index}] Prompt: {prompt}")
        response = generate_response(model, tokenizer, device, prompt, 256)
        print(f" Response: {response}\n")


def interactive_chat():
    """Run a read-generate-print loop on stdin/stdout.

    Blank input is ignored, ``quit`` exits, and ``demo`` runs the demo
    prompts. Ctrl-C and end-of-input (Ctrl-D or exhausted piped stdin) both
    end the loop cleanly.
    """
    model, tokenizer, device = load_model()
    print(f"\n{cfg.domain_name}-SLM Chat (type 'quit' to exit, 'demo' for demos)\n")

    while True:
        try:
            user_text = input("You: ").strip()
            if not user_text:
                continue

            command = user_text.lower()
            if command == "quit":
                break
            if command == "demo":
                demo_generation(model, tokenizer, device)
                continue

            reply = generate_response(model, tokenizer, device, user_text)
            print(f"{cfg.domain_name}-SLM: {reply}\n")
        except (KeyboardInterrupt, EOFError):
            # input() raises EOFError when stdin is closed; treat it like
            # Ctrl-C rather than crashing with a traceback.
            break


if __name__ == "__main__":
    if len(sys.argv) > 1 and sys.argv[1] == "demo":
        demo_model, demo_tokenizer, demo_device = load_model()
        demo_generation(demo_model, demo_tokenizer, demo_device)
    else:
        interactive_chat()