| |
| |
| |
| |
| |
| |
| |
|
|
| import csv |
| import os |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import nltk |
| nltk.download('stopwords', quiet=True) |
|
|
| import eli5 |
| import gradio as gr |
| import lime |
| import lime.lime_tabular |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import shap |
| import joblib |
| from scipy.sparse import hstack, csr_matrix |
|
|
| from utils import (preprocess_text, compute_metadata_features, |
| META_FEATURE_NAMES, FEATURE_DESCRIPTIONS) |
|
|
| |
| |
| |
|
|
# Model artifacts and the feedback log both live in folders next to this script.
models_dir = Path(__file__).with_name('models')
feedback_dir = Path(__file__).with_name('feedback')
# The feedback folder may not exist yet; create it so later appends succeed.
feedback_dir.mkdir(exist_ok=True)
FEEDBACK_CSV = feedback_dir.joinpath('feedback_log.csv')
|
|
# File stems of the joblib artifacts, in the order they are unpacked below.
_ARTIFACT_STEMS = (
    'voting_model',
    'tfidf_vectorizer',
    'meta_scaler',
    'feature_names',
    'optimal_threshold',
    'training_sample',
)

try:
    (voting_model, tfidf_vectorizer, meta_scaler,
     feature_names, optimal_threshold, training_sample) = (
        joblib.load(models_dir / f'{stem}.joblib') for stem in _ARTIFACT_STEMS
    )
    # ELI5 is pointed at the 'rf' sub-estimator when the loaded model exposes
    # fitted named estimators; otherwise the loaded model itself is used.
    raw_rf = (
        voting_model.named_estimators_['rf']
        if hasattr(voting_model, 'named_estimators_')
        else voting_model
    )
    print(f"All models loaded. Threshold = {optimal_threshold:.4f}")
except FileNotFoundError as e:
    # A missing artifact disables everything: the UI checks these for None.
    print(f"Model file not found: {e}")
    voting_model = tfidf_vectorizer = meta_scaler = None
    feature_names = optimal_threshold = training_sample = None
    raw_rf = None
|
|
| |
| |
| |
|
|
# LIME needs both the background sample and the feature names; without
# either one the explainer stays None and LIME output is skipped.
_lime_inputs_loaded = not (training_sample is None or feature_names is None)
lime_explainer = (
    lime.lime_tabular.LimeTabularExplainer(
        training_data=training_sample,
        feature_names=feature_names,
        class_names=['Ham', 'Spam'],
        mode='classification',
    )
    if _lime_inputs_loaded
    else None
)
if _lime_inputs_loaded:
    print("LIME explainer ready.")
|
|
| |
| |
| |
|
|
def classify_email(email_text, threshold):
    """Classify one email against a spam-probability threshold.

    Returns a 4-tuple ``(label, confidence, spam_proba, combined)`` where
    ``combined`` is the sparse feature row fed to the model (TF-IDF columns
    followed by the scaled metadata columns).
    """
    # Text branch: the cleaned text goes through the fitted TF-IDF vectorizer.
    word_features = tfidf_vectorizer.transform([preprocess_text(email_text)])
    # Metadata branch: computed from the raw (uncleaned) text, then scaled.
    scaled_meta = meta_scaler.transform(compute_metadata_features([email_text]))

    combined = hstack([word_features, csr_matrix(scaled_meta)])
    # Column 1 of predict_proba is used as the spam probability.
    spam_proba = voting_model.predict_proba(combined)[0][1]

    is_spam = spam_proba >= threshold
    label = "SPAM" if is_spam else "HAM (Not Spam)"
    # Confidence is always reported for the predicted class.
    confidence = spam_proba if is_spam else 1.0 - spam_proba
    return label, confidence, spam_proba, combined
|
|
| |
| |
| |
|
|
def generate_lime_explanation(combined_features):
    """Explain one feature row with LIME.

    Returns ``(figure, explanation)``, or ``(None, None)`` when the
    module-level LIME explainer was not created.
    """
    if lime_explainer is not None:
        # LIME works on a dense 1-D row, so densify the single sparse sample.
        dense_row = combined_features.toarray()[0]
        lime_result = lime_explainer.explain_instance(
            dense_row,
            voting_model.predict_proba,
            num_features=10,
        )
        figure = lime_result.as_pyplot_figure()
        figure.tight_layout()
        return figure, lime_result
    return None, None
|
|
| |
| |
| |
|
|
| |
| |
def predict_with_meta_only(meta_features, num_tfidf, model):
    """Score metadata-only rows by padding the TF-IDF columns with zeros.

    ``meta_features`` is a 2-D array of metadata columns, ``num_tfidf`` the
    number of all-zero columns to place in front of them, and ``model`` any
    object with ``predict_proba``. Returns whatever ``model.predict_proba``
    returns for the padded matrix.
    """
    row_count = meta_features.shape[0]
    # An empty sparse matrix stands in for "no words present".
    empty_text_block = csr_matrix((row_count, num_tfidf))
    meta_block = csr_matrix(meta_features)
    return model.predict_proba(hstack([empty_text_block, meta_block]))
|
|
|
|
def _spam_class_shap_values(shap_values, num_meta):
    """Reduce KernelExplainer output for one sample to a 1-D spam-class vector."""
    if isinstance(shap_values, list):
        # One array per class: index 1 is the spam class.
        return np.asarray(shap_values[1]).reshape(-1)

    values = np.asarray(shap_values)
    if values.ndim == 3:
        # Stacked layout (samples, features, classes): select the spam class
        # explicitly. Flattening here would interleave both classes' values.
        class_index = 1 if values.shape[-1] > 1 else 0
        return values[0, :, class_index]

    flat = values.reshape(-1)
    return flat[-num_meta:] if flat.size > num_meta else flat


def generate_shap_explanation(email_text):
    """Generate a SHAP bar chart for the metadata features of one email.

    Only the metadata columns are explained; the TF-IDF columns are held at
    zero through ``predict_with_meta_only``.

    Returns ``(figure, shap_values, top_indices)`` where ``shap_values`` is a
    1-D array with one value per metadata feature and ``top_indices`` holds
    up to 10 feature indices ordered by decreasing absolute SHAP value.
    Returns ``(None, None, None)`` when the required artifacts are not loaded.
    """
    if training_sample is None or voting_model is None or meta_scaler is None:
        return None, None, None

    num_meta = len(META_FEATURE_NAMES)
    # The last num_meta columns of the training sample are the metadata block.
    num_tfidf = training_sample.shape[1] - num_meta
    background_meta = training_sample[:50, -num_meta:]
    meta_scaled = meta_scaler.transform(compute_metadata_features([email_text]))

    def shap_predict(meta_features):
        return predict_with_meta_only(meta_features, num_tfidf, voting_model)

    explainer = shap.KernelExplainer(shap_predict, background_meta)
    shap_values = explainer.shap_values(meta_scaled, nsamples=100)
    sv = _spam_class_shap_values(shap_values, num_meta)

    # Largest absolute contributions first, for the textual summaries.
    top_idx = np.argsort(np.abs(sv))[::-1][:10]

    # Ascending order for the chart, so the strongest bar is drawn at the top.
    sorted_indices = np.argsort(np.abs(sv))
    sorted_names = [META_FEATURE_NAMES[idx] for idx in sorted_indices.tolist()]
    sorted_values = sv[sorted_indices]

    fig, ax = plt.subplots(figsize=(8, 6))
    # Red bars push toward spam, blue bars toward ham.
    colors = ['#d62728' if val > 0 else '#1f77b4' for val in sorted_values]
    ax.barh(sorted_names, sorted_values, color=colors)
    ax.set_xlabel('SHAP Value (impact on spam probability)')
    ax.set_title('SHAP Feature Importance (Metadata Features)')
    ax.axvline(x=0, color='black', linewidth=0.5)
    fig.tight_layout()

    return fig, sv, top_idx
|
|
| |
| |
| |
|
|
def generate_eli5_explanation(combined_features):
    """Explain one feature row with ELI5 on the ``raw_rf`` estimator.

    Returns ``(html_string, feature_names_list)``, or ``(None, None)`` when
    the estimator or the feature names are unavailable.
    """
    import re

    if raw_rf is None or feature_names is None:
        return None, None

    dense_row = combined_features.toarray()[0]

    def explain(top):
        return eli5.explain_prediction(
            raw_rf, dense_row, feature_names=feature_names, top=top
        )

    # HTML view: top-10 explanation with every <pre> block stripped out,
    # wrapped in a div that lets long content break onto new lines.
    rendered = eli5.format_as_html(explain(10))
    without_pre = re.sub(r'<pre>.*?</pre>', '', rendered, flags=re.DOTALL)
    html = f'<div style="overflow-wrap: break-word; word-break: break-word;">{without_pre}</div>'

    # Name list: a separate top-5 explanation, positives first, then negatives.
    compact = explain(5)
    top_names = []
    if getattr(compact, 'targets', None):
        weights = compact.targets[0].feature_weights
        top_names = [fw.feature for fw in weights.pos[:5]]
        top_names += [fw.feature for fw in weights.neg[:5]]

    return html, top_names
|
|
| |
| |
| |
|
|
| |
# Badge definitions as (color, icon, text), ordered high / medium / low
# confidence. Icons are written as Unicode escapes so they survive encoding
# round-trips of this file.
# NOTE(review): the two low-confidence icons are assumed to be U+2753
# (question mark); confirm against the intended design.
_SPAM_BADGES = (
    ("red", "\U0001F6A8",  # police car light
     "The model is highly confident this email contains patterns commonly seen in spam or phishing attempts."),
    ("orange", "\u26A0\uFE0F",  # warning sign
     "The model found several spam-like patterns in this email."),
    ("yellow", "\u2753",  # question mark
     "The model leans toward spam, but the evidence is not overwhelming. Use your judgment."),
)
_HAM_BADGES = (
    ("green", "\u2705",  # check mark
     "The model is highly confident this is a legitimate email."),
    ("blue", "\u2139\uFE0F",  # information sign
     "The model found this email to be mostly consistent with legitimate messages."),
    ("gray", "\u2753",  # question mark
     "The model leans toward legitimate, but there are some spam-like features. Review carefully."),
)


def get_result_badge(label, confidence):
    """Return a dict with color, icon, and text describing the classification result.

    ``label`` is treated as spam when it contains the substring "SPAM";
    anything else is treated as ham. ``confidence`` selects the tier:
    above 0.9 is high, above 0.7 is medium, everything else is low.

    Returns a dict with the keys ``"color"``, ``"icon"`` and ``"text"``.
    """
    high, medium, low = _SPAM_BADGES if "SPAM" in label else _HAM_BADGES

    if confidence > 0.9:
        chosen = high
    elif confidence > 0.7:
        chosen = medium
    else:
        chosen = low

    badge_color, badge_icon, badge_text = chosen
    return {"color": badge_color, "icon": badge_icon, "text": badge_text}
|
|
|
|
def _lime_rule_to_feature(rule):
    """Extract the bare feature name from a LIME rule such as ``"0.10 < x <= 0.50"``.

    Strings that do not look like a comparison rule are returned unchanged.
    """
    import re

    number = r'-?\d+(?:\.\d+)?'
    match = re.fullmatch(
        rf'(?:{number}\s*<\s*)?(?P<name>.+?)\s*(?:<=|>=|<|>)\s*{number}',
        rule.strip(),
    )
    return match.group('name') if match else rule


def generate_plain_summary(label, confidence, spam_proba, lime_explanation,
                           shap_sv, shap_top_idx):
    """Build a rule-based plain English summary from XAI results.

    Parameters
    ----------
    label, confidence : predicted label and the confidence in that label
    spam_proba : accepted for interface compatibility; not used in the text
    lime_explanation : object exposing ``as_list()``, or None
    shap_sv, shap_top_idx : SHAP values and ranked feature indices, or None

    Returns
    -------
    A markdown string. Sections for LIME and SHAP appear only when the
    corresponding inputs are present.
    """
    parts = [f"### Classification: **{label}** ({confidence:.0%} confidence)\n\n"]
    lime_rows = lime_explanation.as_list() if lime_explanation is not None else []

    if lime_explanation is not None:
        parts.append("**Key words driving this decision (LIME):**\n")
        for feat_rule, weight in lime_rows[:3]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            parts.append(f"- **{feat_rule}** — {direction}\n")
        parts.append("\n")

    if shap_sv is not None and shap_top_idx is not None:
        parts.append("**Important email characteristics (SHAP):**\n")
        for i in shap_top_idx[:2]:
            feat_name = META_FEATURE_NAMES[i]
            description = FEATURE_DESCRIPTIONS.get(feat_name, feat_name)
            direction = "spam signal" if shap_sv[i] > 0 else "ham signal"
            parts.append(f"- **{feat_name}** ({description}) — {direction}\n")
        parts.append("\n")

    if lime_explanation is not None and shap_top_idx is not None:
        # LIME reports rules like "name > 0.50"; reduce them to bare feature
        # names so they can actually match the SHAP feature names.
        lime_top = {_lime_rule_to_feature(rule) for rule, _ in lime_rows[:10]}
        shap_top = {META_FEATURE_NAMES[i] for i in shap_top_idx[:10]}
        overlap = lime_top & shap_top
        if overlap:
            parts.append(
                f"**Method agreement:** LIME and SHAP both flag: {', '.join(sorted(overlap))}\n\n"
            )

    # Close with the confidence-tier sentence shared with the result badge.
    parts.append(get_result_badge(label, confidence)["text"])
    return "".join(parts)
|
|
| |
| |
| |
|
|
def get_top_features(explanation, method_name):
    """Return up to 3 feature strings taken from an explanation object.

    The XAI tools hand back different object types, so ``method_name``
    selects how the strings are pulled out.

    Parameters
    ----------
    explanation : for "lime" an object exposing ``as_list()``; for "shap" a
        ``(shap_values, top_indices)`` pair; for "eli5" a list of names.
        ``None`` is accepted for every method.
    method_name : one of "lime", "shap", or "eli5"

    Returns
    -------
    A plain Python list with at most 3 strings; empty when the explanation
    is missing or the method name is not recognised.
    """
    if explanation is None:
        return []

    if method_name == "lime":
        return [feat for feat, _weight in explanation.as_list()[:3]]

    if method_name == "shap":
        shap_sv, shap_top_idx = explanation
        if shap_sv is None or shap_top_idx is None:
            return []
        return [META_FEATURE_NAMES[i] for i in shap_top_idx[:3]]

    if method_name == "eli5":
        return list(explanation[:3])

    return []
|
|
|
|
def _feature_from_lime_rule(rule):
    """Return the feature name inside a LIME rule like ``"x > 0.50"``.

    Strings that are not comparison rules come back unchanged.
    """
    import re

    number = r'-?\d+(?:\.\d+)?'
    match = re.fullmatch(
        rf'(?:{number}\s*<\s*)?(?P<name>.+?)\s*(?:<=|>=|<|>)\s*{number}',
        rule.strip(),
    )
    return match.group('name') if match else rule


def generate_comparison(lime_explanation, shap_sv, shap_top_idx, eli5_names):
    """Build a markdown comparison of top features from each XAI method.

    Produces a 5-row table (LIME / SHAP / ELI5), a LIME-SHAP agreement line
    when both are available, and a closing note. Any input may be None;
    missing entries are shown as an em dash.
    """
    empty_cell = "—"
    lime_rows = lime_explanation.as_list() if lime_explanation is not None else []

    lime_top5 = [
        f"{feat} ({'spam' if w > 0 else 'ham'}, {w:+.3f})"
        for feat, w in lime_rows[:5]
    ]

    shap_top5 = []
    if shap_sv is not None and shap_top_idx is not None:
        shap_top5 = [
            f"{META_FEATURE_NAMES[i]} ({'spam' if shap_sv[i] > 0 else 'ham'}, {shap_sv[i]:+.3f})"
            for i in shap_top_idx[:5]
        ]

    eli5_top5 = (eli5_names or [])[:5]

    def cell(cells, rank):
        return cells[rank] if rank < len(cells) else empty_cell

    parts = [
        "### Side-by-Side: Top Features by Method\n\n",
        "| Rank | LIME | SHAP (metadata) | ELI5 |\n",
        "|------|------|-----------------|------|\n",
    ]
    for rank in range(5):
        parts.append(
            f"| {rank+1} | {cell(lime_top5, rank)} | {cell(shap_top5, rank)} "
            f"| {cell(eli5_top5, rank)} |\n"
        )

    if lime_explanation is not None and shap_top_idx is not None:
        # LIME reports rules ("name > 0.50") while SHAP reports bare names;
        # compare on the feature name so shared features are detected.
        lime_set = {_feature_from_lime_rule(rule) for rule, _ in lime_rows[:10]}
        shap_set = {META_FEATURE_NAMES[i] for i in shap_top_idx[:10]}
        overlap_lime_shap = lime_set & shap_set
        parts.append(
            f"\n**LIME-SHAP agreement** (top 10): **{len(overlap_lime_shap)}** shared features"
        )
        if overlap_lime_shap:
            parts.append(f"\nShared: {', '.join(sorted(overlap_lime_shap))}")

    parts.append(
        "\n\n*Note: LIME covers all features (words + metadata), SHAP covers only the "
        f"{len(META_FEATURE_NAMES)} metadata features, "
    )
    parts.append("ELI5 uses the Random Forest sub-estimator's internal weights.*")

    return "".join(parts)
|
|
| |
| |
| |
|
|
def log_feedback(email_text, predicted_label, predicted_confidence, threshold,
                 feedback_type, correct_label=None):
    """Append one feedback row to the CSV log.

    Parameters
    ----------
    email_text : email body; only the first 500 characters are stored
    predicted_label, predicted_confidence : the model's output being rated
    threshold : classification threshold in effect for that prediction
    feedback_type : stored verbatim in the 'feedback' column
    correct_label : user-supplied label, stored as '' when missing

    Returns
    -------
    The number of rows in the log whose feedback is 'wrong'.
    """
    # An existing but empty file also needs the header; otherwise the reader
    # would treat the first data row as column names.
    write_header = (not FEEDBACK_CSV.exists()
                    or FEEDBACK_CSV.stat().st_size == 0)
    with open(FEEDBACK_CSV, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['timestamp', 'email_text', 'predicted_label',
                             'predicted_confidence', 'feedback', 'correct_label',
                             'threshold_used'])
        writer.writerow([
            datetime.now().isoformat(),
            email_text[:500],
            predicted_label,
            f"{predicted_confidence:.4f}",
            feedback_type,
            correct_label or '',
            f"{threshold:.4f}",
        ])
    return count_corrections()
|
|
|
|
def count_corrections():
    """Return how many rows of the feedback log are marked 'wrong'.

    A missing log file counts as zero.
    """
    if not FEEDBACK_CSV.exists():
        return 0
    with open(FEEDBACK_CSV, 'r', encoding='utf-8') as log_file:
        return sum(
            1
            for entry in csv.DictReader(log_file)
            if entry.get('feedback') == 'wrong'
        )
|
|
| |
| |
| |
|
|
# Sample emails offered in the UI, alternating spam-like and legitimate text.
_EXAMPLE_TEXTS = (
    (
        "Subject: URGENT - You Have Won $5,000,000!!!\n\n"
        "Dear Friend,\n\n"
        "CONGRATULATIONS!!! You have been selected as the winner of our international lottery program!!!\n"
        "To claim your $5,000,000 USD prize, click the link below IMMEDIATELY and provide your bank details.\n\n"
        "ACT NOW - This offer expires in 24 hours!!!\n\n"
        "Click here: http://totally-legit-prize.com/claim\n"
        "Send $500 processing fee to unlock your winnings.\n\n"
        "Best regards,\n"
        "Dr. Prince Mohammed"
    ),
    (
        "Subject: Team sync Thursday 2pm\n\n"
        "Hi everyone,\n\n"
        "Just a reminder that we have our weekly team sync this Thursday at 2pm in Conference Room B.\n\n"
        "Agenda:\n"
        "- Sprint review\n"
        "- Q2 planning discussion\n"
        "- New hire onboarding update\n\n"
        "Please come prepared with your status updates.\n\n"
        "Thanks,\n"
        "Sarah"
    ),
    (
        "Subject: Your account has been compromised!\n\n"
        "Dear Customer,\n\n"
        "We detected suspicious activity on your account. Click here immediately to verify your identity: http://secure-bank-login.com/verify\n\n"
        "If you do not verify within 24 hours, your account will be permanently locked.\n\n"
        "Security Team"
    ),
    (
        "Subject: Thanksgiving dinner plans\n\n"
        "Hi everyone!\n\n"
        "I wanted to start planning for Thanksgiving dinner. I'm thinking we could do it at my place this year. What does everyone think about 4pm?\n\n"
        "Let me know if you have any dietary restrictions or if you want to bring a dish.\n\n"
        "Love,\n"
        "Mom"
    ),
    (
        "Subject: Best prices on V1AGRA and C1ALIS!!!\n\n"
        "$$$ SAVE BIG $$$\n"
        "Buy now and get 80% OFF!!!\n"
        "No prescription needed! Free shipping!\n"
        "Order at http://cheap-pharma-deals.com\n\n"
        "LIMITED TIME OFFER - ACT NOW!"
    ),
)

# One single-item row per example, because the examples widget is bound to
# exactly one input component.
EXAMPLE_EMAILS = [[text] for text in _EXAMPLE_TEXTS]
|
|
| |
| |
| |
|
|
def classify_and_explain(email_text, uploaded_file, threshold):
    """Main function called by Gradio. Returns all outputs for all tabs + feedback state.

    Parameters
    ----------
    email_text : text typed into the textbox (may be None or blank)
    uploaded_file : None, a filesystem path, or an upload object exposing the
        path as ``.name``; when given, its contents replace ``email_text``
    threshold : spam-probability cut-off passed to ``classify_email``

    Returns
    -------
    A 7-tuple: result markdown, LIME figure, SHAP figure, ELI5 HTML,
    comparison markdown, summary markdown, and the packed feedback state
    ``label|||confidence|||threshold|||first 500 chars of the email``.
    On early exit only the message slots are filled and the state is "".
    """

    def message_only(result_text, eli5_text=""):
        return (result_text, None, None, eli5_text, "", "", "")

    if uploaded_file is not None:
        try:
            # Plain paths are used directly; other upload objects are
            # expected to carry the temp-file path in `.name`.
            if isinstance(uploaded_file, (str, os.PathLike)):
                upload_path = uploaded_file
            else:
                upload_path = getattr(uploaded_file, 'name', uploaded_file)
            email_text = Path(upload_path).read_text(encoding='utf-8')
        except (OSError, TypeError, ValueError) as e:
            # ValueError also covers undecodable (non-UTF-8) file contents.
            print(f"File read error: {e}")
            return message_only("Could not read file.", "Error reading file.")

    if email_text is None or email_text.strip() == '':
        return message_only("Please enter email text or upload a file.")

    if voting_model is None:
        return message_only("Models not found. Run `python3 train.py` first.")

    label, confidence, spam_proba, combined = classify_email(email_text, threshold)

    # Each explainer is best-effort: a failure in one must not hide the
    # classification result or the other explanations.
    try:
        lime_fig, lime_exp = generate_lime_explanation(combined)
    except Exception as e:
        print(f"LIME error: {e}")
        lime_fig, lime_exp = None, None

    try:
        shap_fig, shap_sv, shap_top_idx = generate_shap_explanation(email_text)
    except Exception as e:
        print(f"SHAP error: {e}")
        shap_fig, shap_sv, shap_top_idx = None, None, None

    try:
        eli5_html, eli5_names = generate_eli5_explanation(combined)
    except Exception as e:
        print(f"ELI5 error: {e}")
        eli5_html, eli5_names = None, None

    result_md = f"## {'SPAM' if 'SPAM' in label else 'HAM (Not Spam)'}\n\n"
    result_md += f"**Confidence:** {confidence:.1%}\n\n"
    result_md += f"**Threshold:** {threshold:.0%}\n\n"
    result_md += f"**Spam probability:** {spam_proba:.1%}\n\n"
    if lime_exp is not None:
        result_md += "**Key factors:**\n"
        for feat_rule, weight in lime_exp.as_list()[:5]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            result_md += f"- **{feat_rule}** {direction}\n"

    comparison_md = generate_comparison(lime_exp, shap_sv, shap_top_idx, eli5_names)
    summary_md = generate_plain_summary(label, confidence, spam_proba, lime_exp, shap_sv, shap_top_idx)
    eli5_display = eli5_html or "<p>ELI5 explanation not available.</p>"

    # The last element is the hidden state read back by the feedback buttons.
    return (result_md, lime_fig, shap_fig, eli5_display, comparison_md, summary_md,
            f"{label}|||{confidence:.4f}|||{threshold:.4f}|||{email_text[:500]}")
|
|
| |
| |
| |
|
|
def handle_feedback(hidden_state, is_correct, user_label=""):
    """Record the user's verdict on the most recent classification.

    Parameters
    ----------
    hidden_state : packed string ``label|||confidence|||threshold|||email``
        produced by the classify step, or an empty value when nothing has
        been classified yet
    is_correct : True logs the row as 'correct', False as 'wrong'
    user_label : label chosen by the user; only stored for 'wrong' feedback

    Returns
    -------
    A status message for display. Missing or malformed state yields the
    "No classification to give feedback on." message instead of an error.
    """
    nothing_to_rate = "No classification to give feedback on."
    if not hidden_state:
        return nothing_to_rate

    # Split at most 3 times so a '|||' inside the email text stays part of
    # the email instead of truncating it.
    parts = hidden_state.split('|||', 3)
    if len(parts) < 4:
        return nothing_to_rate

    predicted_label, confidence_text, threshold_text, email_text = parts
    try:
        predicted_confidence = float(confidence_text)
        threshold_used = float(threshold_text)
    except ValueError:
        return nothing_to_rate

    if is_correct:
        corrections = log_feedback(email_text, predicted_label, predicted_confidence,
                                   threshold_used, 'correct')
        return f"Thanks for the feedback! ({corrections} corrections collected so far)"

    corrections = log_feedback(email_text, predicted_label, predicted_confidence,
                               threshold_used, 'wrong', user_label)
    return f"Correction logged! ({corrections} corrections collected so far)"
|
|
| |
| |
| |
|
|
# Markdown shown in the "How It Works" tab.
HOW_IT_WORKS_MD = """
## How This App Works

### What is spam classification?
Spam classification automatically identifies unwanted or malicious emails (spam) vs. legitimate messages (ham). This helps protect users from phishing scams, fraudulent offers, and unwanted advertising.

### The Model
This app uses a **Voting Ensemble** — three different machine learning models that each "vote" on whether an email is spam:
- **Random Forest** — builds many decision trees and takes the majority vote
- **Logistic Regression** — finds a mathematical boundary between spam and ham
- **Support Vector Machine (SVM)** — finds the widest possible margin between classes

By combining all three, the ensemble is more accurate than any single model alone.

### Feature Extraction
The model looks at two types of features:
- **TF-IDF (Term Frequency-Inverse Document Frequency)** — measures how important each word is. Common spam words like "prize" or "click" get high scores.
- **24 Metadata Features** — structural patterns like exclamation mark density, dollar sign count, ALL CAPS ratio, URL count, and more.

### Explainable AI (XAI) Methods
This app doesn't just classify — it explains **why**:

- **LIME** — Removes words one at a time and watches how the prediction changes. Shows which words matter most.
- **SHAP** — Uses game theory to calculate each feature's "fair share" of the prediction. Based on Nobel Prize-winning mathematics.
- **ELI5** — Looks directly at the model's internal weights to show which features it relies on most.

### Feedback & Retraining
When you click "Correct" or "Wrong", your feedback is saved. After enough corrections accumulate, the model can be retrained with the new examples to improve over time. This is called **human-in-the-loop machine learning**.

### Disclaimer
This model was created as a university course project. It is intended for **educational and research purposes only** and should not be used as a sole spam filter in production. Always use established email security tools for real-world spam filtering.
"""
|
|
# Visual theme for the app: Soft base theme with custom hues and web fonts.
_theme_options = {
    "primary_hue": "blue",
    "secondary_hue": "red",
    "neutral_hue": "slate",
    "font": gr.themes.GoogleFont("Inter"),
    "font_mono": gr.themes.GoogleFont("IBM Plex Mono"),
}
theme = gr.themes.Soft(**_theme_options)
|
|
| custom_css = """ |
| /* ββ Container ββ */ |
| .gradio-container { |
| max-width: 1600px !important; |
| margin: 0 auto !important; |
| padding: 1.5rem 2rem !important; |
| } |
| |
| /* ββ Top bar ββ */ |
| .topbar { |
| background: linear-gradient(135deg, #f8fafc 0%, #eef2ff 100%); |
| border: 1px solid #e2e8f0; |
| border-radius: 14px; |
| padding: 1.4rem 1.8rem 1.2rem; |
| margin-bottom: 1.2rem; |
| box-shadow: 0 1px 3px rgba(0,0,0,0.06); |
| text-align: center; |
| } |
| .topbar-title { |
| font-size: 22px; |
| font-weight: 700; |
| color: #1e293b; |
| margin: 0 0 0.3rem; |
| } |
| .topbar-subtitle { |
| font-size: 13px; |
| color: #64748b; |
| margin: 0 0 0.7rem; |
| } |
| .topbar-badges { |
| display: flex; |
| justify-content: center; |
| gap: 0.5rem; |
| flex-wrap: wrap; |
| } |
| .topbar-badge { |
| display: inline-block; |
| background: #e0e7ff; |
| color: #3730a3; |
| font-size: 11.5px; |
| font-weight: 600; |
| padding: 0.25rem 0.7rem; |
| border-radius: 999px; |
| letter-spacing: 0.02em; |
| } |
| |
| /* ββ Input panel (left column) ββ */ |
| .input-panel { |
| background: linear-gradient(180deg, #ffffff 0%, #f8fafc 100%); |
| border: 1px solid #e2e8f0; |
| border-radius: 14px; |
| padding: 1.2rem; |
| box-shadow: 0 1px 3px rgba(0,0,0,0.04); |
| } |
| |
| /* ββ Output panel (right column) ββ */ |
| .output-panel { |
| background: linear-gradient(180deg, #ffffff 0%, #f8fafc 100%); |
| border: 1px solid #e2e8f0; |
| border-radius: 14px; |
| padding: 1.2rem; |
| box-shadow: 0 1px 3px rgba(0,0,0,0.04); |
| } |
| .output-panel .plot-container { |
| max-height: 420px; |
| overflow-y: auto; |
| } |
| .output-panel .prose { |
| max-height: 420px; |
| overflow-y: auto; |
| } |
| |
| /* ββ Feedback card ββ */ |
| .feedback-card { |
| background: linear-gradient(135deg, #f8fafc 0%, #f1f5f9 100%); |
| border: 1px solid #e2e8f0; |
| border-radius: 14px; |
| padding: 1rem 1.4rem; |
| margin-top: 1rem; |
| box-shadow: 0 1px 3px rgba(0,0,0,0.04); |
| } |
| |
| /* ββ Classify button ββ */ |
| .classify-btn button { |
| border-radius: 10px !important; |
| } |
| |
| /* ββ Responsive ββ */ |
| @media (max-width: 980px) { |
| .gradio-container { |
| padding: 1rem !important; |
| } |
| .topbar { |
| padding: 1rem 1.2rem; |
| } |
| .input-panel, .output-panel { |
| min-width: 0 !important; |
| } |
| } |
| """ |
|
|
# Static header markup rendered at the top of the page; the CSS classes are
# styled by the app's custom stylesheet. Built line by line, with an empty
# first and last entry so the text starts and ends with a newline.
TOPBAR_HTML = "\n".join([
    '',
    '<div class="topbar">',
    '<div class="topbar-title">Spam Email Classifier with XAI</div>',
    '<div class="topbar-subtitle">',
    'Classify emails as spam or ham and understand <strong>why</strong> using',
    'LIME, SHAP, and ELI5 explainable AI methods',
    '</div>',
    '<div class="topbar-badges">',
    '<span class="topbar-badge">Ensemble Model</span>',
    '<span class="topbar-badge">LIME</span>',
    '<span class="topbar-badge">SHAP</span>',
    '<span class="topbar-badge">ELI5</span>',
    '<span class="topbar-badge">97.4% Accuracy</span>',
    '</div>',
    '</div>',
    '',
])
|
|
# Gradio layout: inputs on the left, tabbed outputs on the right, and a
# feedback card underneath.
# NOTE(review): nesting was reconstructed from statement order; the feedback
# card is assumed to sit at the top level of the page — confirm.
with gr.Blocks(title="Spam Email Classifier with XAI", theme=theme, css=custom_css) as demo:
    gr.HTML(TOPBAR_HTML)

    # Carries the last classification to the feedback buttons as one
    # '|||'-delimited string.
    hidden_state = gr.State("")

    with gr.Row(equal_height=False):
        with gr.Column(scale=2, min_width=360, elem_classes="input-panel"):
            email_input = gr.Textbox(
                label="Email Text",
                placeholder="Paste your email here...",
                lines=8,
                autoscroll=False,
            )
            file_input = gr.File(
                label="Or upload a .txt file",
                file_types=['.txt'],
            )
            threshold_slider = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.05,
                # Compare against None so a stored threshold of 0.0 is kept.
                value=float(optimal_threshold) if optimal_threshold is not None else 0.5,
                label="Classification Threshold",
                info="Emails with spam probability at or above this are classified as spam.",
            )
            classify_btn = gr.Button("Classify", variant="primary", size="lg",
                                     elem_classes="classify-btn")
            with gr.Accordion("Example Emails", open=False):
                gr.Examples(
                    examples=EXAMPLE_EMAILS,
                    inputs=[email_input],
                    label="Click to load an example",
                    cache_examples=False,
                )

        with gr.Column(scale=3, min_width=480, elem_classes="output-panel"):
            with gr.Tabs():
                with gr.Tab("Result"):
                    result_output = gr.Markdown(label="Classification Result")
                with gr.Tab("LIME"):
                    gr.Markdown("*LIME perturbs the input and fits a local model "
                                "to see which features matter most.*")
                    lime_output = gr.Plot(label="LIME Explanation")
                with gr.Tab("SHAP"):
                    gr.Markdown("*SHAP uses game theory to assign each feature "
                                "a contribution value.*")
                    shap_output = gr.Plot(label="SHAP Explanation")
                with gr.Tab("ELI5"):
                    gr.Markdown("*ELI5 shows feature weights directly from the "
                                "model's internals.*")
                    eli5_output = gr.HTML(label="ELI5 Explanation")
                with gr.Tab("Compare"):
                    compare_output = gr.Markdown(label="Method Comparison")
                with gr.Tab("Summary"):
                    summary_output = gr.Markdown(label="Plain English Summary")
                with gr.Tab("How It Works"):
                    gr.Markdown(HOW_IT_WORKS_MD)

    with gr.Group(elem_classes="feedback-card"):
        with gr.Row():
            feedback_msg = gr.Markdown("**Was this classification correct?**")
            correct_btn = gr.Button("Correct", variant="secondary", scale=0,
                                    min_width=100)
            wrong_btn = gr.Button("Wrong", variant="stop", scale=0,
                                  min_width=100)
            correction_dropdown = gr.Dropdown(
                choices=["Spam", "Ham"],
                label="Correct label",
                scale=0,
                min_width=120,
            )

    # The outputs follow the order of the tuple returned by classify_and_explain.
    classify_btn.click(
        fn=classify_and_explain,
        inputs=[email_input, file_input, threshold_slider],
        outputs=[result_output, lime_output, shap_output, eli5_output,
                 compare_output, summary_output, hidden_state],
    )

    # Both buttons share one handler; a constant State supplies is_correct.
    correct_btn.click(
        fn=handle_feedback,
        inputs=[hidden_state, gr.State(True)],
        outputs=[feedback_msg],
    )
    wrong_btn.click(
        fn=handle_feedback,
        inputs=[hidden_state, gr.State(False), correction_dropdown],
        outputs=[feedback_msg],
    )
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Turn on request queuing, then serve on every network interface at
    # port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)
|
|