# slm-bahasa-id / generate.py
# Uploaded by romizone ("Upload SLM Bahasa Indonesia", commit 9815efc, verified)
# File size: 2.88 kB
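"""
Text Generation - Small Language Model
======================================
Author: Jekardah AI Lab

Usage:
    python generate.py                              # run the built-in demo prompts
    python generate.py --prompt "indonesia adalah"  # complete a single prompt
    python generate.py --interactive                # interactive prompt loop

Common options: --model-dir DIR, --max-tokens N, --temperature T
"""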
import os
import sys
import argparse
import torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "bpe-tokenizer-id"))
from model import SmallLM
from bpe_tokenizer import BPETokenizer
def load_model(model_dir="./output_slm"):
    """Load the language model and its tokenizer from *model_dir* for inference.

    Both ``SmallLM.from_pretrained`` and ``BPETokenizer.from_pretrained`` read
    from the same directory.

    Args:
        model_dir: Directory containing the saved model and tokenizer. A
            relative path is resolved against the current working directory.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = SmallLM.from_pretrained(model_dir)
    # This script only samples from the model, so switch to evaluation mode
    # to disable train-only behaviour such as dropout. Assumes SmallLM is a
    # torch.nn.Module (it consumes torch tensors) -- TODO confirm.
    model.eval()
    tokenizer = BPETokenizer.from_pretrained(model_dir)
    return model, tokenizer
def generate(model, tokenizer, prompt, max_tokens=50, temperature=0.8, top_k=40):
    """Sample a continuation of *prompt* and return it as decoded text.

    The prompt is lower-cased before encoding. Presumably the tokenizer was
    trained on lower-cased text; TODO confirm.

    Args:
        model: Object exposing ``generate(input_ids, max_new_tokens=...,
            temperature=..., top_k=...)``. Its result is indexed with ``[0]``,
            i.e. assumed to be a batch of token-id sequences.
        tokenizer: Object exposing ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        prompt: Text to continue.
        max_tokens: Number of new tokens to sample.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.

    Returns:
        The decoded first sequence returned by ``model.generate``. Whether it
        includes the prompt depends on ``model.generate``.

    Raises:
        ValueError: If the prompt encodes to zero tokens.
    """
    ids = tokenizer.encode(prompt.lower())
    if not ids:
        # A (1, 0) input tensor would fail somewhere opaque inside the model.
        raise ValueError(f"prompt {prompt!r} encodes to no tokens")
    inp = torch.tensor([ids], dtype=torch.long)  # batch of one sequence
    # Inference only: skip autograd graph construction while sampling.
    with torch.no_grad():
        out = model.generate(inp, max_new_tokens=max_tokens,
                             temperature=temperature, top_k=top_k)
    return tokenizer.decode(out[0].tolist())
def demo(model, tokenizer, max_tokens=30, temperature=0.8, top_k=40):
    """Generate completions for a fixed list of Indonesian prompts and print them.

    Args:
        model: Loaded language model.
        tokenizer: Tokenizer matching *model*.
        max_tokens: Number of new tokens sampled per prompt.
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Top-k cutoff forwarded to ``generate``.

    Each completion is truncated to its first 80 characters for display.
    """
    print("\n" + "=" * 60)
    print("🇮🇩 SLM Bahasa Indonesia - Demo")
    print("=" * 60)
    prompts = [
        "indonesia adalah",
        "pendidikan",
        "jakarta",
        "makan nasi",
        "teknologi",
        "kebudayaan indonesia",
        "demokrasi",
        "ekonomi",
        "saya suka",
        "hutan",
    ]
    for p in prompts:
        text = generate(model, tokenizer, p, max_tokens=max_tokens,
                        temperature=temperature, top_k=top_k)
        print(f"\n📝 \"{p}\"")
        print(f"🤖 {text[:80]}")
    print("\n" + "=" * 60)


def interactive(model, tokenizer, max_tokens=50, temperature=0.8, top_k=40):
    """Read prompts from stdin in a loop and print the model's continuation.

    The loop ends on 'quit' / 'exit' / 'q' (case-insensitive), on EOF, or on
    Ctrl-C. Blank input is ignored.

    Args:
        model: Loaded language model.
        tokenizer: Tokenizer matching *model*.
        max_tokens: Number of new tokens sampled per prompt.
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Top-k cutoff forwarded to ``generate``.
    """
    print("\n" + "=" * 60)
    print("🇮🇩 SLM Bahasa Indonesia - Interactive")
    print("=" * 60)
    print("Ketik awalan kalimat, model melanjutkan.")
    print("Ketik 'quit' untuk keluar.\n")
    while True:
        try:
            prompt = input("📝 Input: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Move off the unfinished input line before saying goodbye.
            print()
            break
        if prompt.lower() in ('quit', 'exit', 'q'):
            break
        if not prompt:
            continue
        text = generate(model, tokenizer, prompt, max_tokens=max_tokens,
                        temperature=temperature, top_k=top_k)
        print(f"🤖 {text}\n")
    print("👋 Bye!")


def main():
    """Parse CLI arguments, load the model, and run the selected mode.

    Modes: ``--interactive`` takes precedence, then a non-empty ``--prompt``.
    Otherwise the built-in demo runs. Sampling options apply to every mode.
    When ``--max-tokens`` is omitted, each mode keeps its own default.
    """
    parser = argparse.ArgumentParser(description="SLM Generator")
    parser.add_argument("--model-dir", default="./output_slm",
                        help="directory holding the saved model and tokenizer")
    parser.add_argument("--prompt", type=str,
                        help="generate a single completion for this prompt")
    parser.add_argument("--interactive", action="store_true",
                        help="start an interactive prompt loop")
    # None means "use the mode's own default" (50, or 30 for the demo).
    parser.add_argument("--max-tokens", type=int, default=None,
                        help="new tokens to generate (default: 50; 30 in demo mode)")
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=40)
    args = parser.parse_args()

    model, tokenizer = load_model(args.model_dir)

    # Forward the sampling options to whichever mode runs.
    sampling = {"temperature": args.temperature, "top_k": args.top_k}
    if args.max_tokens is not None:
        sampling["max_tokens"] = args.max_tokens

    if args.interactive:
        interactive(model, tokenizer, **sampling)
    elif args.prompt:
        print(generate(model, tokenizer, args.prompt, **sampling))
    else:
        demo(model, tokenizer, **sampling)


if __name__ == "__main__":
    main()