"""Small-model evaluation hub: load tiny Hugging Face causal LMs and stream text.

The Gradio UI lets a user pick (or type) a model id, load it onto the
available device, tune generation parameters, and watch the output stream in.
A side panel shows live CPU / memory / disk usage and offers a button that
deletes the local Hugging Face hub cache.
"""

import gc
import os
import queue
import shutil
from threading import Thread

import gradio as gr
import psutil
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face hub download cache; "Clean HF Cache" deletes and recreates it.
HF_CACHE_DIR = os.path.expanduser("~/.cache/huggingface/hub")

# Model ids offered in the dropdown (the dropdown also accepts custom ids).
MODELS = [
    'HuggingFaceTB/SmolLM2-135M',
    'AxiomicLabs/GPT-X2-125M',
    'AxiomicLabs/GPT-X-125M',
    'facebook/MobileLLM-R1-140M-base',
    'SupraLabs/Supra-50M-Base',
    'CompactAI-O/Shard-1',
    'SupraLabs/Supra-50M-Instruct',
    'HuggingFaceTB/SmolLM-135M',
    'facebook/opt-125m',
    'AxiomicLabs/GPT-S-5M',
    'openai-community/gpt2',
    'LH-Tech-AI/Spark-5M-Base-v4',
    'SupraLabs/Supra-Mini-v5-8M',
    'EleutherAI/pythia-70m',
    'SupraLabs/Supra-Mini-v4-2M',
    'EleutherAI/pythia-31m',
    'StentorLabs/Stentor3-50M',
    'StentorLabs/Stentor3-20M',
    'StentorLabs/Portimbria-150M',
    'HuggingFaceTB/nanowhale-100m-base',
    'EleutherAI/pythia-14m',
    'Harley-ml/Tenete-8M',
    'Harley-ml/Dillion-1.2M',
    'MihaiPopa-1/CinnabarLM-1.4M-Base',
    'MihaiPopa-1/CinnabarLM-4M-Base',
    'MihaiPopa-1/PotentSulfurLM-500K-Base',
    'MihaiPopa-1/CinnabarLM-1.5M-Base',
    'Harley-ml/Dillionv2-1.3M',
    'Eclipse-Senpai/KeyLM-75M',
    'SupraLabs/Supra-Mini-v6-1M',
    'AxiomicLabs/GPT-S-1.4M',
    'GODELEV/Archaea-74M',
    'Sandroeth/cali-0.1B',
    'veyra-ai/veyra3-5m-base',
    'veyra-ai/veyra-30m-base-5b-tokens',
    'ThingAI/Quark-50m',
    'ThingAI/Quark-135m',
]

# Bytes per GiB, used when formatting memory/disk figures.
_BYTES_PER_GB = 1024**3

# Seconds the streamer waits for the next text chunk before giving up.
_STREAM_TIMEOUT_S = 15.0


class ModelManager:
    """Holds the single model/tokenizer pair shared by all UI callbacks."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Prefer the GPU whenever PyTorch can see one.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"


model_manager = ModelManager()


def _to_gb(num_bytes):
    """Convert a byte count to GiB rounded to two decimals."""
    return round(num_bytes / _BYTES_PER_GB, 2)


def get_system_stats():
    """Return current CPU, memory and disk usage as display strings.

    CPU usage is sampled non-blockingly (``interval=None``): psutil reports
    utilisation since the previous call, and the UI timer calls this
    periodically, so no worker is tied up sleeping for a measurement window.
    """
    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')
    return {
        "CPU": f"{psutil.cpu_percent(interval=None)}%",
        "Memory": f"{_to_gb(mem.used)} / {_to_gb(mem.total)} GB",
        "Disk": f"{_to_gb(disk.used)} / {_to_gb(disk.total)} GB",
    }


# Prime psutil's CPU counter: the first non-blocking call only returns 0.0.
psutil.cpu_percent(interval=None)


def load_new_model(model_id):
    """Load ``model_id`` into the global ``model_manager``.

    The previously loaded model is released first so two models never have
    to be resident at once. An empty selection keeps the current model.

    Returns:
        A status message for the UI (success, or the error text on failure).
    """
    if model_id is None or not str(model_id).strip():
        return "Please select or enter a model id first."
    model_id = str(model_id).strip()

    # Drop references to the old model and reclaim its memory before loading.
    model_manager.model = None
    model_manager.tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        # SECURITY NOTE: trust_remote_code=True executes Python code shipped in
        # the model repository, and the dropdown accepts arbitrary ids. Only
        # load repositories you trust.
        # Model and tokenizer are loaded explicitly (not via pipeline) so the
        # tokenizer can drive a TextIteratorStreamer.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_id, trust_remote_code=True
        ).to(model_manager.device)
    except Exception as e:  # UI boundary: report any load failure as status text
        return f"Error loading model: {e}"

    model_manager.tokenizer = tokenizer
    model_manager.model = model
    return f"Successfully loaded {model_id} on {model_manager.device.upper()}"


def run_inference(user_prompt, max_tokens, temperature, top_k, top_p,
                  rep_penalty, ngram_size, do_sample):
    """Stream text generated by the loaded model as a continuation of ``user_prompt``.

    ``model.generate`` runs in a background thread that feeds a
    ``TextIteratorStreamer``. After every new chunk, this generator yields the
    prompt plus all text produced so far, which Gradio renders as a live stream.

    Args:
        user_prompt: Text to continue.
        max_tokens: Maximum number of new tokens (``max_new_tokens``).
        temperature, top_k, top_p: Sampling controls, used only when
            ``do_sample`` is true.
        rep_penalty: ``repetition_penalty`` (1.0 disables it).
        ngram_size: ``no_repeat_ngram_size`` (0 disables it).
        do_sample: Sample when true; greedy decoding when false.

    Yields:
        The accumulated text. Yields a short message instead when no model is
        loaded or the prompt tokenizes to nothing. When generation fails or
        times out, a bracketed note is appended to the text produced so far.
    """
    # Snapshot both objects once, so a concurrent "Load" cannot swap them mid-call.
    tokenizer = model_manager.tokenizer
    model = model_manager.model
    if model is None or tokenizer is None:
        yield "Please load a model first."
        return

    inputs = tokenizer([user_prompt], return_tensors="pt").to(model_manager.device)
    if inputs["input_ids"].shape[-1] == 0:
        # Example: an empty prompt with a tokenizer that adds no BOS token.
        # generate() needs at least one input token to start from.
        yield "Please enter a prompt."
        return

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=_STREAM_TIMEOUT_S,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    if do_sample:
        sampling_kwargs = dict(
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )
    else:
        # Greedy decoding ignores these. Pass GenerationConfig's neutral
        # defaults rather than the UI values so transformers does not warn
        # about sampling flags being set while do_sample=False.
        sampling_kwargs = dict(temperature=1.0, top_k=50, top_p=1.0)

    generate_kwargs = dict(
        **inputs,
        **sampling_kwargs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        repetition_penalty=float(rep_penalty),
        no_repeat_ngram_size=int(ngram_size),
        do_sample=bool(do_sample),
        pad_token_id=tokenizer.eos_token_id,  # prevents padding warnings
    )

    errors = []

    def _generate():
        """Run generate(); on failure, record the error and close the stream."""
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # surfaced to the UI by the consumer below
            errors.append(exc)
            # Signal end-of-stream so the consumer loop stops immediately
            # instead of waiting out the streamer timeout.
            streamer.end()

    # Daemon thread, so an in-flight generation does not block interpreter exit.
    thread = Thread(target=_generate, daemon=True)
    thread.start()

    generated_text = user_prompt
    try:
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
    except queue.Empty:
        # The streamer raises queue.Empty when no chunk arrives within its timeout.
        yield (
            f"{generated_text}\n\n[Generation timed out after "
            f"{_STREAM_TIMEOUT_S:g}s waiting for tokens]"
        )
        return

    thread.join()
    if errors:
        yield f"{generated_text}\n\n[Generation error: {errors[0]}]"


def clean_cache():
    """Delete the Hugging Face hub cache directory and recreate it empty.

    Returns:
        A status message for the UI.
    """
    if not os.path.exists(HF_CACHE_DIR):
        return "Cache directory not found."
    try:
        shutil.rmtree(HF_CACHE_DIR)
        os.makedirs(HF_CACHE_DIR, exist_ok=True)
    except OSError as e:  # e.g. permission denied, or files locked by another process
        return f"Error cleaning cache: {e}"
    return "Cache cleaned successfully!"


# Gradio interface
with gr.Blocks(title="Small MF Model Tester", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🚀 Small Model Evaluation Hub with Streaming")

    with gr.Row():
        # Left column: settings & monitoring
        with gr.Column(scale=1):
            with gr.Accordion("System Monitoring", open=True):
                stats_output = gr.JSON(label="Live System Stats")
                # Refresh the stats panel every 2 seconds.
                gr.Timer(2).tick(get_system_stats, None, stats_output)

            with gr.Group():
                gr.Markdown("### Model Loader")
                with gr.Row():
                    model_id_input = gr.Dropdown(
                        choices=MODELS,
                        label="Model",
                        allow_custom_value=True,
                        show_label=False,
                        scale=3,
                    )
                    load_btn = gr.Button("Load", variant="secondary", scale=1)
                status_output = gr.Markdown("Status: *Waiting to load model...*")
                clean_btn = gr.Button("Clean HF Cache", variant="stop", size="sm")

            with gr.Accordion("Generation Configuration", open=False):
                do_sample_input = gr.Checkbox(
                    label="Enable Sampling (do_sample)",
                    value=True,
                    info="Uncheck for greedy decoding",
                )
                max_tokens_input = gr.Slider(
                    minimum=10, maximum=2048, value=128, step=1,
                    label="Max Output Tokens",
                )
                temperature_input = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature", info="Higher = more creative",
                )

                gr.Markdown("#### Advanced Sampling Constraints")
                top_k_input = gr.Slider(
                    minimum=0, maximum=100, value=50, step=1,
                    label="Top-K", info="0 = disabled",
                )
                top_p_input = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top-P (Nucleus)", info="1.0 = disabled",
                )
                rep_penalty_input = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="Repetition Penalty", info="1.0 = disabled",
                )
                ngram_size_input = gr.Slider(
                    minimum=0, maximum=10, value=0, step=1,
                    label="No Repeat N-Gram Size", info="0 = disabled",
                )

        # Right column: interaction
        with gr.Column(scale=2):
            user_prompt = gr.Textbox(
                label="Prompt",
                value="Once upon a time in a digital kingdom,",
                placeholder="Enter your prompt here...",
                lines=5,
            )
            run_btn = gr.Button("Generate text", variant="primary", size="lg")
            output_text = gr.Textbox(
                label="Result", lines=15, buttons=["copy"], autoscroll=True
            )

    # Events
    load_btn.click(
        fn=load_new_model,
        inputs=[model_id_input],
        outputs=[status_output],
    )

    # run_inference is a generator function, so Gradio streams each yielded value.
    run_btn.click(
        fn=run_inference,
        inputs=[
            user_prompt,
            max_tokens_input,
            temperature_input,
            top_k_input,
            top_p_input,
            rep_penalty_input,
            ngram_size_input,
            do_sample_input,
        ],
        outputs=[output_text],
    )

    clean_btn.click(fn=clean_cache, outputs=[status_output])


if __name__ == "__main__":
    app.launch()