"""
app.py
------
Gradio web UI for live side-by-side comparison of baseline vs pruned vs
pruned+quantized DistilBERT sentiment models.

Usage:
    python app.py

Then open http://localhost:7860 in your browser.
"""

import sys
import os
import time
import traceback

# --- Monkeypatch for Gradio 4.x / HuggingFace Hub 0.30+ incompatibility ---
# Installed before `import gradio` below so Gradio sees the shim when the
# installed huggingface_hub no longer provides HfFolder.
import huggingface_hub

if not hasattr(huggingface_hub, "HfFolder"):

    class HfFolder:
        @staticmethod
        def get_token():
            return None

    huggingface_hub.HfFolder = HfFolder
# --------------------------------------------------------------------------

import torch
import psutil
import pandas as pd
import gradio as gr

# Resolve src/ relative to this file (not the current working directory) so
# the app can be launched from anywhere.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))

from load_model import (
    load_baseline_model,
    load_tokenizer,
    load_pruned_model,
    load_pruned_quantized_model,
    get_model_size_mb,
    PRUNED_SAVE_PATH,
    PRUNED_QUANTIZED_SAVE_PATH,
)
from evaluate import predict_single

# ── Load models once at startup ──────────────────────────────────────────
print("[*] Loading models for Gradio app...")
tokenizer = load_tokenizer()
baseline_model = load_baseline_model()
baseline_size = get_model_size_mb(model=baseline_model)

# The compressed models are optional: if their saved files are missing the UI
# shows "N/A". The model names are still bound (to None) so later code can
# reference them unconditionally.
try:
    pruned_model = load_pruned_model()
    pruned_available = True
    pruned_size = get_model_size_mb(path=PRUNED_SAVE_PATH)
except FileNotFoundError as e:
    pruned_model = None
    pruned_available = False
    pruned_size = None
    print(f"[!] Pruned model not found: {e}")

try:
    pq_model = load_pruned_quantized_model()
    pq_available = True
    pq_size = get_model_size_mb(path=PRUNED_QUANTIZED_SAVE_PATH)
except FileNotFoundError as e:
    pq_model = None
    pq_available = False
    pq_size = None
    print(f"[!] Pruned+Quantized model not found: {e}")

print("[+] Models ready.")

# Sample sentences for the race (simulating a large workload by repeating)
TEST_SENTENCES = [
    "This movie was absolutely fantastic, I loved every minute of it.",
    "The food was terrible and the service was even worse.",
    "An average experience, nothing special but nothing terrible either.",
    "One of the best books I have ever read in my entire life.",
    "Complete waste of time and money. Would not recommend.",
] * 100  # 500 total reviews

# Tokenize the race workload once at startup. Each review stays a separate
# single-sentence encoding, so the race runs 500 forward passes of batch size
# 1; tokenization was never part of the timed region, this only avoids
# redoing it on every click.
RACE_INPUTS = [
    tokenizer(t, return_tensors="pt", truncation=True, max_length=128, padding=True)
    for t in TEST_SENTENCES
]

# Upper bound on torch intra-op threads when NOT simulating an edge device.
# Cap at 4 to prevent a thread bomb on a shared (e.g. Docker) host.
MAX_THREADS = 4

# Placeholder for the four result fields of a model that is not available.
_NA_RESULT = ("N/A",) * 4


def _configure_threads(simulate_edge):
    """Set torch's process-wide intra-op thread count for the next inference.

    Uses 1 thread when `simulate_edge` is truthy, otherwise the number of
    visible cores capped at MAX_THREADS. Both the single-inference tab and the
    race tab use this, so switching tabs does not flip the thread count.
    """
    target = 1 if simulate_edge else min(os.cpu_count() or MAX_THREADS, MAX_THREADS)
    if torch.get_num_threads() != target:
        torch.set_num_threads(target)


def _format_result(label, conf, time_ms, size_mb):
    """Render one model's prediction as the 4 strings shown in its UI column.

    Returns (prediction, confidence, inference time, model size). `conf` is
    multiplied by 100 for display, so it is presumably a 0-1 probability.
    """
    emoji = "😊 POSITIVE" if label == "POSITIVE" else "😞 NEGATIVE"
    return emoji, f"{round(conf * 100, 1)}%", f"{time_ms} ms", f"{size_mb} MB"


def _message_outputs(message, summary_text):
    """Build the 13-tuple for run_comparison that shows `message` in every
    Prediction box and `summary_text` in the summary."""
    return (message, "", "", "") * 3 + (summary_text,)


# ── Inference function (Single) ────────────────────────────────────────────────
def run_comparison(text, simulate_edge):
    """Run `text` through all available models and format the outputs.

    Returns a 13-tuple matching `outputs_single`: 4 fields (prediction,
    confidence, time, size) for each of baseline / pruned / pruned+quantized,
    followed by a Markdown summary. Unavailable models yield "N/A" fields.
    Errors are reported in the UI rather than raised.
    """
    try:
        if not text or not text.strip():
            return _message_outputs("⚠️ Please enter text.", "")

        _configure_threads(simulate_edge)

        # 1. Baseline
        b_label, b_conf, b_time = predict_single(baseline_model, tokenizer, text)
        baseline_fields = _format_result(b_label, b_conf, b_time, baseline_size)

        # 2. Pruned
        if pruned_available:
            p_label, p_conf, p_time = predict_single(pruned_model, tokenizer, text)
            pruned_fields = _format_result(p_label, p_conf, p_time, pruned_size)
        else:
            pruned_fields = _NA_RESULT

        # 3. Pruned + Quantized
        if pq_available:
            q_label, q_conf, q_time = predict_single(pq_model, tokenizer, text)
            quant_fields = _format_result(q_label, q_conf, q_time, pq_size)

            speedup = round(b_time / q_time, 2) if q_time > 0 else "N/A"
            size_reduction = round(((baseline_size - pq_size) / baseline_size) * 100, 1)
            summary = (
                f"⚡ Final model is **{speedup}x faster** | "
                f"💾 **{size_reduction}% smaller** | "
                f"🎯 Predictions Match? **{'Yes' if b_label == q_label else 'No!'}**"
            )
        else:
            quant_fields = _NA_RESULT
            summary = "⚠️ Pruned+Quantized model not found."

        return (*baseline_fields, *pruned_fields, *quant_fields, summary)
    except Exception as e:
        # UI boundary: keep the app alive, but leave a trace on the server.
        traceback.print_exc()
        err = f"Error: {str(e)}"
        return _message_outputs(err, err)


# ── Inference functions (Race - Synchronous) ───────────────────────────────────
def run_batch(model, simulate_edge):
    """Time `model` over the pre-tokenized race workload.

    Returns a Markdown status string with the wall time in seconds and the
    system-wide CPU utilisation reported by psutil for the interval between
    the two `psutil.cpu_percent` calls (i.e. across all cores, not just this
    process). Errors are returned as a message rather than raised.
    """
    try:
        _configure_threads(simulate_edge)

        # Reset the CPU counter right before inference starts
        psutil.cpu_percent(interval=None)
        start = time.perf_counter()  # monotonic clock for elapsed time

        with torch.no_grad():
            for encoded in RACE_INPUTS:
                model(**encoded)

        total_time = round(time.perf_counter() - start, 2)
        # Average CPU usage since the reset call just before the loop
        cpu_usage = psutil.cpu_percent(interval=None)

        return f"✅ **Done!** ({total_time} seconds)\n\n📈 CPU Usage during run: **{cpu_usage}%**"
    except Exception as e:
        traceback.print_exc()
        return f"❌ Error: {str(e)}"


def _race(model, available, simulate_edge):
    """Stream race status for one model: "N/A" if it is unavailable,
    otherwise a running marker followed by the run_batch result."""
    if not available:
        yield "N/A"
        return
    yield "🔄 Running..."
    yield run_batch(model, simulate_edge)


def race_baseline(simulate_edge):
    """Generator callback for the baseline race button."""
    yield from _race(baseline_model, True, simulate_edge)


def race_pruned(simulate_edge):
    """Generator callback for the pruned race button."""
    yield from _race(pruned_model, pruned_available, simulate_edge)


def race_quantized(simulate_edge):
    """Generator callback for the pruned+quantized race button."""
    yield from _race(pq_model, pq_available, simulate_edge)


# ── Gradio UI helpers ──────────────────────────────────────────────────────────
_RESULT_FIELD_LABELS = ("Prediction", "Confidence", "Inference Time", "Model Size")


def _result_column(title):
    """Create one result column (header + 4 read-only textboxes) in the
    currently open Gradio layout and return the textboxes in display order."""
    with gr.Column():
        gr.HTML(f"<h3>{title}</h3>")
        fields = tuple(
            gr.Textbox(label=label, interactive=False) for label in _RESULT_FIELD_LABELS
        )
    return fields


def _race_column(title, button_text, variant):
    """Create one race column in the currently open layout; return its
    (status Markdown, start Button)."""
    with gr.Column(elem_classes="race-box"):
        gr.HTML(f"<h3>{title}</h3>")
        status = gr.Markdown("Waiting to start...")
        button = gr.Button(button_text, variant=variant)
    return status, button


# ── Gradio UI ──────────────────────────────────────────────────────────────────
with gr.Blocks(
    title="BERT Compression POC",
    theme=gr.themes.Base(
        primary_hue="blue",
        neutral_hue="slate",
        font=[gr.themes.GoogleFont("IBM Plex Mono"), "monospace"],
    ),
    css="""
    .container { max-width: 1000px; margin: auto; }
    .header { text-align: center; padding: 20px 0 10px 0; }
    .summary-box { background: #0ea5e9; color: white; border-radius: 8px; padding: 12px 16px; margin-top: 12px; }
    .race-box { border: 2px solid #334155; border-radius: 8px; padding: 20px; text-align: center; font-size: 1.2rem; }
    footer { display: none !important; }
    /* Force Dark Mode styling for the entire container to prevent light mode blinding */
    body.gradio-container, .gradio-container { background-color: #0b0f19 !important; color: #f1f5f9 !important; }
    """,
) as demo:
    gr.HTML(
        """
        <div class="header">
          <h1>🔬 Edge Model Compression Pipeline POC</h1>
          <p>Comparing Distillation vs Pruning vs Quantization</p>
        </div>
        """
    )

    with gr.Tabs():
        # TAB 1: Single Inference
        with gr.Tab("Single Inference"):
            with gr.Row():
                text_input = gr.Textbox(
                    placeholder="Type a sentence to analyse sentiment...",
                    label="Input Text",
                    lines=2,
                    scale=4,
                )
                with gr.Column(scale=1):
                    simulate_edge = gr.Checkbox(
                        label="🔌 Simulate Edge Device (1 CPU Core)", value=False
                    )
                    run_btn = gr.Button("▶ Run Inference", variant="primary")

            with gr.Row():
                b_pred, b_conf, b_time, b_size = _result_column("1. Baseline (DistilBERT)")
                p_pred, p_conf, p_time, p_size = _result_column("2. Pruned (20% Sparsity)")
                q_pred, q_conf, q_time, q_size = _result_column("3. Pruned + Quantized (INT8)")

            summary = gr.Markdown(value="*Run pipeline to see the comparison summary.*")

            outputs_single = [
                b_pred, b_conf, b_time, b_size,
                p_pred, p_conf, p_time, p_size,
                q_pred, q_conf, q_time, q_size,
                summary,
            ]
            run_btn.click(
                fn=run_comparison,
                inputs=[text_input, simulate_edge],
                outputs=outputs_single,
            )

        # TAB 2: The Edge Race
        with gr.Tab("🚀 The Edge Race"):
            gr.Markdown(
                "### Processing 500 Reviews (Batch Inference)\n"
                "*See how each model performs under heavy load.*"
            )
            simulate_edge_race = gr.Checkbox(
                label="🔌 Simulate Edge Device (1 CPU Core) - Highly Recommended for Demo!",
                value=True,
            )

            with gr.Row():
                race_b, btn_b = _race_column("📦 Baseline", "🏁 Run Baseline Race", "primary")
                race_p, btn_p = _race_column("✂️ Pruned", "🏁 Run Pruned Race", "secondary")
                race_q, btn_q = _race_column(
                    "⚡ Pruned+Quantized", "🏁 Run Quantized Race", "secondary"
                )

            btn_b.click(fn=race_baseline, inputs=[simulate_edge_race], outputs=race_b)
            btn_p.click(fn=race_pruned, inputs=[simulate_edge_race], outputs=race_p)
            btn_q.click(fn=race_quantized, inputs=[simulate_edge_race], outputs=race_q)


if __name__ == "__main__":
    demo.launch()