| """ |
| Text Generation - Small Language Model |
| ======================================= |
| Author: Jekardah AI Lab |
| |
| Usage: |
| python generate.py |
| python generate.py --prompt "indonesia adalah" |
| python generate.py --interactive |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import torch |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "bpe-tokenizer-id")) |
| from model import SmallLM |
| from bpe_tokenizer import BPETokenizer |
|
|
|
|
def load_model(model_dir="./output_slm"):
    """Load the language model and its tokenizer from one directory.

    Args:
        model_dir: Directory handed to both ``SmallLM.from_pretrained`` and
            ``BPETokenizer.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # The model is restored first, then the tokenizer, both from the same path.
    return (
        SmallLM.from_pretrained(model_dir),
        BPETokenizer.from_pretrained(model_dir),
    )
|
|
|
|
def generate(model, tokenizer, prompt, max_tokens=50, temperature=0.8, top_k=40):
    """Continue ``prompt`` with the model and return the decoded text.

    The prompt is lower-cased, encoded with ``tokenizer.encode``, wrapped
    into a batch of one, and passed to ``model.generate``. The first row of
    whatever that call returns is decoded with ``tokenizer.decode``.

    Args:
        model: Object exposing ``generate(input_ids, max_new_tokens=...,
            temperature=..., top_k=...)``.
        tokenizer: Object exposing ``encode(str)`` and ``decode(list)``.
        prompt: Text to continue.
        max_tokens: Forwarded to the model as ``max_new_tokens``.
        temperature: Sampling temperature forwarded to the model.
        top_k: Top-k cutoff forwarded to the model.

    Returns:
        The decoded string for the first batch row. Whether it includes the
        prompt depends on ``model.generate``, which is not visible here.
    """
    # NOTE(review): lower-casing presumably matches how the tokenizer was
    # trained -- confirm against the tokenizer's training pipeline.
    token_ids = tokenizer.encode(prompt.lower())
    batch = torch.tensor([token_ids], dtype=torch.long)

    # Sampling never needs gradients; disabling autograd keeps the call from
    # recording a graph for every generated step if the model does not
    # already do so itself.
    with torch.no_grad():
        output = model.generate(
            batch,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )

    return tokenizer.decode(output[0].tolist())
|
|
|
|
def demo(model, tokenizer):
    """Print short continuations for a fixed list of Indonesian prompts.

    Each prompt is continued with ``generate`` using ``max_tokens=30`` and
    the result is printed truncated to its first 80 characters.

    Args:
        model: Model object passed through to ``generate``.
        tokenizer: Tokenizer object passed through to ``generate``.
    """
    # The emoji literals had been corrupted into stray Thai characters by a
    # bad re-encoding. They are written as escapes so the source stays ASCII.
    # NOTE(review): the flag is recoverable from the surviving bytes; the
    # memo and robot glyphs are best guesses -- confirm.
    flag = "\U0001F1EE\U0001F1E9"
    prompt_mark = "\U0001F4DD"
    model_mark = "\U0001F916"

    rule = "=" * 60
    print("\n" + rule)
    print(f"{flag} SLM Bahasa Indonesia - Demo")
    print(rule)

    prompts = [
        "indonesia adalah",
        "pendidikan",
        "jakarta",
        "makan nasi",
        "teknologi",
        "kebudayaan indonesia",
        "demokrasi",
        "ekonomi",
        "saya suka",
        "hutan",
    ]

    for p in prompts:
        text = generate(model, tokenizer, p, max_tokens=30)
        print(f'\n{prompt_mark} "{p}"')
        # Keep each sample to one short line of output.
        print(f"{model_mark} {text[:80]}")

    print("\n" + rule)
|
|
|
|
def interactive(model, tokenizer, max_tokens=50, temperature=0.8, top_k=40):
    """Read prompts from stdin and print the model's continuation for each.

    The loop ends on ``quit``, ``exit`` or ``q`` (case-insensitive), on
    end-of-file, or on Ctrl-C. Blank input is ignored.

    Args:
        model: Model object passed through to ``generate``.
        tokenizer: Tokenizer object passed through to ``generate``.
        max_tokens: Forwarded to ``generate`` for every prompt.
        temperature: Forwarded to ``generate`` for every prompt.
        top_k: Forwarded to ``generate`` for every prompt.
    """
    # The emoji literals had been corrupted into stray Thai characters by a
    # bad re-encoding. They are written as escapes so the source stays ASCII.
    # NOTE(review): the flag is recoverable from the surviving bytes; the
    # memo, robot and waving-hand glyphs are best guesses -- confirm.
    flag = "\U0001F1EE\U0001F1E9"
    prompt_mark = "\U0001F4DD"
    model_mark = "\U0001F916"
    bye_mark = "\U0001F44B"

    rule = "=" * 60
    print("\n" + rule)
    print(f"{flag} SLM Bahasa Indonesia - Interactive")
    print(rule)
    print("Ketik awalan kalimat, model melanjutkan.")
    print("Ketik 'quit' untuk keluar.\n")

    while True:
        try:
            prompt = input(f"{prompt_mark} Input: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C is treated as a normal exit.
            break
        if prompt.lower() in ("quit", "exit", "q"):
            break
        if not prompt:
            continue

        text = generate(
            model,
            tokenizer,
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        print(f"{model_mark} {text}\n")

    print(f"{bye_mark} Bye!")


def main():
    """Parse command-line options and run one of the three modes.

    ``--interactive`` starts the prompt loop, ``--prompt TEXT`` prints a
    single continuation, and with neither the fixed demo is run. The
    sampling options apply to both the interactive and single-prompt modes.
    """
    parser = argparse.ArgumentParser(description="SLM Generator")
    parser.add_argument("--model-dir", default="./output_slm")
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--max-tokens", type=int, default=50)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=40)
    args = parser.parse_args()

    model, tokenizer = load_model(args.model_dir)

    # One set of sampling options shared by both modes, so the flags are not
    # silently ignored when --interactive is given.
    sampling = {
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "top_k": args.top_k,
    }

    if args.interactive:
        interactive(model, tokenizer, **sampling)
    elif args.prompt:
        print(generate(model, tokenizer, args.prompt, **sampling))
    else:
        demo(model, tokenizer)
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|