VoltageVagabond's picture
Upload folder using huggingface_hub
960ec3d verified
Raw
History Blame
33.1 kB
# app.py — Spam Email Classifier with Explanations
# This file runs the Gradio web app.
# It loads a trained model, classifies an email as spam or not spam,
# and shows three different explanations of why it made that choice.
# University course project — Explainable AI for spam detection
# Features: LIME, SHAP, ELI5, side-by-side comparison, plain English summary,
# user feedback logging, and batch retrain support.
import csv
import os
from datetime import datetime
from pathlib import Path
import nltk
nltk.download('stopwords', quiet=True)
import eli5
import gradio as gr
import lime
import lime.lime_tabular
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import shap
import joblib
from scipy.sparse import hstack, csr_matrix
from utils import (preprocess_text, compute_metadata_features,
META_FEATURE_NAMES, FEATURE_DESCRIPTIONS)
# ---------------------------------------------------------------------------
# 1. Model Loading
# ---------------------------------------------------------------------------
# Locations of the persisted model artifacts and of the user-feedback log,
# both resolved relative to this file.
models_dir = Path(__file__).parent / 'models'
feedback_dir = models_dir.with_name('feedback')
feedback_dir.mkdir(exist_ok=True)
FEEDBACK_CSV = feedback_dir / 'feedback_log.csv'

try:
    voting_model = joblib.load(models_dir / 'voting_model.joblib')
    tfidf_vectorizer = joblib.load(models_dir / 'tfidf_vectorizer.joblib')
    meta_scaler = joblib.load(models_dir / 'meta_scaler.joblib')
    feature_names = joblib.load(models_dir / 'feature_names.joblib')
    optimal_threshold = joblib.load(models_dir / 'optimal_threshold.joblib')
    training_sample = joblib.load(models_dir / 'training_sample.joblib')
    # A VotingClassifier (full train) exposes its fitted members through
    # `named_estimators_`; a plain RandomForestClassifier (fast train) is
    # used directly as the forest.
    raw_rf = (
        voting_model.named_estimators_['rf']
        if hasattr(voting_model, 'named_estimators_')
        else voting_model
    )
    print(f"All models loaded. Threshold = {optimal_threshold:.4f}")
except FileNotFoundError as missing:
    # Any missing artifact disables the whole pipeline; downstream code
    # checks these names for None before using them.
    print(f"Model file not found: {missing}")
    voting_model = tfidf_vectorizer = meta_scaler = None
    feature_names = optimal_threshold = training_sample = raw_rf = None
# ---------------------------------------------------------------------------
# 2. LIME Explainer Setup
# ---------------------------------------------------------------------------
# Build the tabular LIME explainer only when both the background sample and
# the feature names were loaded; otherwise keep it as None so LIME is skipped.
if training_sample is None or feature_names is None:
    lime_explainer = None
else:
    lime_explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=training_sample,
        feature_names=feature_names,
        class_names=['Ham', 'Spam'],
        mode='classification',
    )
    print("LIME explainer ready.")
# ---------------------------------------------------------------------------
# 3. classify_email
# ---------------------------------------------------------------------------
def classify_email(email_text, threshold):
    """Classify one email against a spam-probability threshold.

    Returns ``(label, confidence, spam_proba, combined_features)``:
    ``confidence`` is the probability of whichever class was chosen, and
    ``combined_features`` is the TF-IDF + scaled-metadata row given to the model.
    """
    # Word features come from the cleaned text, metadata from the raw text.
    word_part = tfidf_vectorizer.transform([preprocess_text(email_text)])
    meta_part = csr_matrix(
        meta_scaler.transform(compute_metadata_features([email_text]))
    )
    combined = hstack([word_part, meta_part])

    # Column 1 of predict_proba is the spam class.
    spam_proba = voting_model.predict_proba(combined)[0][1]
    is_spam = spam_proba >= threshold
    label = "SPAM" if is_spam else "HAM (Not Spam)"
    confidence = spam_proba if is_spam else 1.0 - spam_proba
    return label, confidence, spam_proba, combined
# ---------------------------------------------------------------------------
# 4. LIME explanation
# ---------------------------------------------------------------------------
def generate_lime_explanation(combined_features):
    """Explain one feature row with LIME.

    Returns ``(figure, explanation)``, or ``(None, None)`` when no LIME
    explainer was built at startup.
    """
    if lime_explainer is not None:
        # LIME works on a dense 1-D row, so densify the sparse input first.
        dense_row = combined_features.toarray()[0]
        explanation = lime_explainer.explain_instance(
            dense_row,
            voting_model.predict_proba,
            num_features=10,
        )
        figure = explanation.as_pyplot_figure()
        figure.tight_layout()
        return figure, explanation
    return None, None
# ---------------------------------------------------------------------------
# 5. SHAP explanation (metadata features only — fast)
# ---------------------------------------------------------------------------
# This function runs the model using only the metadata features, not the word features.
# It fills in zeros for the word (TF-IDF) part so the model gets the right input shape.
def predict_with_meta_only(meta_features, num_tfidf, model):
    """Score samples using metadata only, with the word block zeroed out.

    ``meta_features`` is a 2-D array with one row per sample. An all-zero
    sparse block of ``num_tfidf`` columns is placed to its left so the model
    receives the input width it expects. Returns ``model.predict_proba`` of
    that combined matrix.
    """
    meta_block = csr_matrix(meta_features)
    zero_words = csr_matrix((meta_features.shape[0], num_tfidf))
    return model.predict_proba(hstack([zero_words, meta_block]))
def _spam_class_shap_values(shap_values, num_meta):
    """Reduce KernelExplainer output for one sample to a 1-D spam-class vector."""
    if isinstance(shap_values, list):
        # Older SHAP releases: one (n_samples, n_features) array per class;
        # the spam class is the last entry.
        values = np.asarray(shap_values[-1])
    else:
        values = np.asarray(shap_values)
        if values.ndim == 3:
            # Newer SHAP releases: a single (n_samples, n_features, n_classes)
            # array. Flattening it would interleave both classes, so select
            # the first sample and the spam class explicitly.
            values = values[0, :, -1]
    flat = values.reshape(-1)
    if len(flat) > num_meta:
        # Fallback kept from the original code for any other unexpected shape.
        flat = flat[-num_meta:]
    return flat


def generate_shap_explanation(email_text):
    """Build a SHAP bar chart over the metadata features of one email.

    Only the metadata columns are explained; the TF-IDF columns are held at
    zero through ``predict_with_meta_only``. The first 50 rows of the
    training sample serve as the KernelExplainer background.

    Returns ``(figure, shap_values, top_indices)`` where ``shap_values`` is a
    1-D array of spam-class contributions (one per metadata feature) and
    ``top_indices`` lists up to 10 feature indices by descending absolute
    contribution. Returns ``(None, None, None)`` when the models are not loaded.
    """
    if training_sample is None or voting_model is None:
        return None, None, None

    num_meta = len(META_FEATURE_NAMES)
    num_tfidf = training_sample.shape[1] - num_meta
    # The metadata columns are the trailing num_meta columns of the sample.
    background_meta = training_sample[:50, -num_meta:]
    meta_scaled = meta_scaler.transform(compute_metadata_features([email_text]))

    # Wrap the top-level function so SHAP can call it with just meta_features.
    def shap_predict(meta_features):
        return predict_with_meta_only(meta_features, num_tfidf, voting_model)

    explainer = shap.KernelExplainer(shap_predict, background_meta)
    shap_values = explainer.shap_values(meta_scaled, nsamples=100)
    sv = _spam_class_shap_values(shap_values, num_meta)

    # One argsort serves both the chart (ascending, so the largest bar is
    # drawn at the top of the barh plot) and the top-10 list (descending).
    ascending = np.argsort(np.abs(sv))
    top_idx = ascending[::-1][:10]
    sorted_names = [META_FEATURE_NAMES[idx] for idx in ascending.tolist()]
    sorted_values = sv[ascending]

    fig, ax = plt.subplots(figsize=(8, 6))
    # Red bars push toward spam, blue bars toward ham.
    colors = ['#d62728' if val > 0 else '#1f77b4' for val in sorted_values]
    ax.barh(sorted_names, sorted_values, color=colors)
    ax.set_xlabel('SHAP Value (impact on spam probability)')
    ax.set_title('SHAP Feature Importance (Metadata Features)')
    ax.axvline(x=0, color='black', linewidth=0.5)
    fig.tight_layout()
    return fig, sv, top_idx
# ---------------------------------------------------------------------------
# 6. ELI5 explanation
# ---------------------------------------------------------------------------
def generate_eli5_explanation(combined_features):
    """Explain one feature row with ELI5 using the Random Forest estimator.

    Returns ``(html_string, feature_names_list)``, or ``(None, None)`` when
    the forest or the feature names are unavailable.
    """
    import re

    if raw_rf is None or feature_names is None:
        return None, None

    dense_row = combined_features.toarray()[0]
    full_exp = eli5.explain_prediction(raw_rf, dense_row, feature_names=feature_names, top=10)
    raw_html = eli5.format_as_html(full_exp)
    # ELI5 outputs a fragment (no <body> tag). Strip the <pre> boilerplate
    # block — it contains methodology text that overflows the container and
    # isn't useful for students. Keep the <style> and the contribution table.
    clean_html = re.sub(r'<pre>.*?</pre>', '', raw_html, flags=re.DOTALL)
    html = f'<div style="overflow-wrap: break-word; word-break: break-word;">{clean_html}</div>'

    # A second, shorter explanation supplies the names used for comparison:
    # positive contributors first, then negative ones.
    short_exp = eli5.explain_prediction(raw_rf, dense_row, feature_names=feature_names, top=5)
    top_names = []
    if getattr(short_exp, 'targets', None):
        weights = short_exp.targets[0].feature_weights
        top_names = [fw.feature for fw in weights.pos[:5]]
        top_names += [fw.feature for fw in weights.neg[:5]]
    return html, top_names
# ---------------------------------------------------------------------------
# 7. Plain English summary (replaces Ollama LLM)
# ---------------------------------------------------------------------------
# Pick a color, icon, and verdict text based on the label and confidence score.
def get_result_badge(label, confidence):
    """Return a dict with color, icon, and text describing the classification result.

    ``label`` is treated as spam when it contains the substring ``"SPAM"``;
    anything else is treated as ham. ``confidence`` selects one of three
    tiers: above 0.9 (high), above 0.7 (medium), otherwise low.

    Returns ``{"color": ..., "icon": ..., "text": ...}``.
    """
    # Each tuple is (color, icon, text), ordered high / medium / low confidence.
    spam_tiers = (
        ("red", "🚨",
         "The model is highly confident this email contains patterns commonly seen in spam or phishing attempts."),
        ("orange", "⚠️",
         "The model found several spam-like patterns in this email."),
        ("yellow", "❓",
         "The model leans toward spam, but the evidence is not overwhelming. Use your judgment."),
    )
    ham_tiers = (
        # The check-mark icon was stored mis-encoded; it is restored here.
        ("green", "✅",
         "The model is highly confident this is a legitimate email."),
        ("blue", "ℹ️",
         "The model found this email to be mostly consistent with legitimate messages."),
        ("gray", "❓",
         "The model leans toward legitimate, but there are some spam-like features. Review carefully."),
    )

    if confidence > 0.9:
        tier = 0
    elif confidence > 0.7:
        tier = 1
    else:
        tier = 2

    tiers = spam_tiers if "SPAM" in label else ham_tiers
    badge_color, badge_icon, badge_text = tiers[tier]
    return {"color": badge_color, "icon": badge_icon, "text": badge_text}
def generate_plain_summary(label, confidence, spam_proba, lime_explanation,
                           shap_sv, shap_top_idx):
    """Build a rule-based plain English summary from XAI results.

    Parameters
    ----------
    label, confidence : predicted label string and its confidence (0..1)
    spam_proba : unused here; kept so existing callers keep working
    lime_explanation : LIME Explanation object, or None if LIME was skipped
    shap_sv, shap_top_idx : SHAP values and ranked metadata indices, or None

    Returns
    -------
    A markdown string: headline, top LIME rules, top SHAP metadata features,
    any LIME/SHAP agreement, and a closing verdict sentence.
    """
    parts = [f"### Classification: **{label}** ({confidence:.0%} confidence)\n\n"]

    if lime_explanation is not None:
        parts.append("**Key words driving this decision (LIME):**\n")
        for feat_rule, weight in lime_explanation.as_list()[:3]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            parts.append(f"- **{feat_rule}** — {direction}\n")
        parts.append("\n")

    if shap_sv is not None and shap_top_idx is not None:
        parts.append("**Important email characteristics (SHAP):**\n")
        for i in shap_top_idx[:2]:
            feat_name = META_FEATURE_NAMES[i]
            description = FEATURE_DESCRIPTIONS.get(feat_name, feat_name)
            direction = "spam signal" if shap_sv[i] > 0 else "ham signal"
            parts.append(f"- **{feat_name}** ({description}) — {direction}\n")
        parts.append("\n")

    if lime_explanation is not None and shap_top_idx is not None:
        # as_list() yields rule strings such as "name <= 0.00", which can
        # never equal a bare metadata name. Resolve LIME's feature indices
        # through the feature-name list instead so the overlap is meaningful.
        # NOTE(review): assumes the metadata entries of `feature_names` are
        # spelled exactly like META_FEATURE_NAMES — confirm against training.
        lime_map = lime_explanation.as_map()
        lime_pairs = lime_map.get(1) or next(iter(lime_map.values()), [])
        lime_top = {str(feature_names[idx]) for idx, _ in lime_pairs[:10]}
        shap_top = {META_FEATURE_NAMES[i] for i in shap_top_idx[:10]}
        overlap = lime_top & shap_top
        if overlap:
            parts.append(
                f"**Method agreement:** LIME and SHAP both flag: {', '.join(sorted(overlap))}\n\n"
            )

    # Close with the verdict sentence for this label + confidence.
    parts.append(get_result_badge(label, confidence)["text"])
    return "".join(parts)
# ---------------------------------------------------------------------------
# 8. Side-by-side comparison
# ---------------------------------------------------------------------------
def _lime_feature_names(explanation, limit):
    """Return up to *limit* bare feature names in a LIME explanation's rank order."""
    # as_map() maps label -> [(feature_index, weight), ...]. Label 1 is the
    # one as_list() reports by default; fall back to whatever label exists.
    weights_by_label = explanation.as_map()
    ranked = weights_by_label.get(1) or next(iter(weights_by_label.values()), [])
    return [str(feature_names[idx]) for idx, _ in ranked[:limit]]


def get_top_features(explanation, method_name):
    """Return a list of up to 3 plain feature name strings from an explanation object.

    Each XAI tool returns a different object type, so method_name decides how
    the names are extracted.

    Parameters
    ----------
    explanation : LIME Explanation, a ``(shap_values, top_indices)`` tuple,
        a list of ELI5 feature names, or None
    method_name : one of "lime", "shap", or "eli5"

    Returns
    -------
    A plain Python list of feature names (up to 3 items). It is empty when
    the explanation is None or the method name is not recognised.
    """
    if explanation is None:
        return []
    if method_name == "lime":
        # Bare names, not as_list() rule strings such as "name <= 0.00",
        # so they can be compared with the other methods' names.
        return _lime_feature_names(explanation, 3)
    if method_name == "shap":
        shap_sv, shap_top_idx = explanation
        if shap_sv is None or shap_top_idx is None:
            return []
        return [META_FEATURE_NAMES[i] for i in shap_top_idx[:3]]
    if method_name == "eli5":
        # ELI5 already provides a plain list of feature names.
        return list(explanation[:3])
    return []


def generate_comparison(lime_explanation, shap_sv, shap_top_idx, eli5_names):
    """Build a markdown comparison of top features from each XAI method.

    Produces a 5-row table (LIME rule / SHAP metadata feature / ELI5 feature
    per rank), followed by the LIME-SHAP agreement over each method's top 10,
    any feature present in all three top-3 lists, and a scope note. Any
    argument may be None; missing cells are shown as an em dash.
    """
    md = "### Side-by-Side: Top Features by Method\n\n"
    md += "| Rank | LIME | SHAP (metadata) | ELI5 |\n"
    md += "|------|------|-----------------|------|\n"

    lime_top5 = []
    if lime_explanation is not None:
        for feat, w in lime_explanation.as_list()[:5]:
            direction = "spam" if w > 0 else "ham"
            lime_top5.append(f"{feat} ({direction}, {w:+.3f})")

    shap_top5 = []
    if shap_sv is not None and shap_top_idx is not None:
        for i in shap_top_idx[:5]:
            direction = "spam" if shap_sv[i] > 0 else "ham"
            shap_top5.append(f"{META_FEATURE_NAMES[i]} ({direction}, {shap_sv[i]:+.3f})")

    eli5_top5 = (eli5_names or [])[:5]

    placeholder = "—"
    for rank in range(5):
        lime_cell = lime_top5[rank] if rank < len(lime_top5) else placeholder
        shap_cell = shap_top5[rank] if rank < len(shap_top5) else placeholder
        eli5_cell = eli5_top5[rank] if rank < len(eli5_top5) else placeholder
        md += f"| {rank+1} | {lime_cell} | {shap_cell} | {eli5_cell} |\n"

    # Features that appear in every method's top-3 list.
    lime_names = get_top_features(lime_explanation, "lime")
    shap_names = get_top_features((shap_sv, shap_top_idx), "shap")
    eli5_names_short = get_top_features(eli5_names, "eli5")
    three_way = [
        name for name in lime_names
        if name in shap_names and name in eli5_names_short
    ]

    if lime_explanation is not None and shap_top_idx is not None:
        # Compare bare feature names on both sides; LIME rule strings would
        # never match a metadata feature name.
        lime_set = set(_lime_feature_names(lime_explanation, 10))
        shap_set = {META_FEATURE_NAMES[i] for i in shap_top_idx[:10]}
        overlap_lime_shap = lime_set & shap_set
        md += f"\n**LIME-SHAP agreement** (top 10): **{len(overlap_lime_shap)}** shared features"
        if overlap_lime_shap:
            md += f"\nShared: {', '.join(sorted(overlap_lime_shap))}"

    if three_way:
        md += f"\n\n**All three methods agree on** (top 3): {', '.join(three_way)}"

    md += "\n\n*Note: LIME covers all features (words + metadata), SHAP covers only the 24 metadata features, "
    md += "ELI5 uses the Random Forest sub-estimator's internal weights.*"
    return md
# ---------------------------------------------------------------------------
# 9. Feedback logging
# ---------------------------------------------------------------------------
def log_feedback(email_text, predicted_label, predicted_confidence, threshold,
                 feedback_type, correct_label=None):
    """Append one feedback row to the CSV log.

    The email text is truncated to its first 500 characters and the numeric
    fields are written with four decimals. ``correct_label`` is stored as an
    empty string when it is None or empty.

    Returns the number of 'wrong' entries in the log after the append.
    """
    # A missing *or empty* file needs the header row. Without it the reader
    # would take the first data row as the header and miscount corrections.
    write_header = (not FEEDBACK_CSV.exists()
                    or FEEDBACK_CSV.stat().st_size == 0)
    with open(FEEDBACK_CSV, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['timestamp', 'email_text', 'predicted_label',
                             'predicted_confidence', 'feedback', 'correct_label',
                             'threshold_used'])
        writer.writerow([
            datetime.now().isoformat(),
            email_text[:500],
            predicted_label,
            f"{predicted_confidence:.4f}",
            feedback_type,
            correct_label or '',
            f"{threshold:.4f}",
        ])
    return count_corrections()
def count_corrections():
    """Return how many rows in the feedback log are marked 'wrong' (0 if no log exists)."""
    if FEEDBACK_CSV.exists():
        with open(FEEDBACK_CSV, 'r', encoding='utf-8') as log_file:
            return sum(
                1
                for entry in csv.DictReader(log_file)
                if entry.get('feedback') == 'wrong'
            )
    return 0
# ---------------------------------------------------------------------------
# 10. Example Emails
# ---------------------------------------------------------------------------
# Sample emails offered in the UI's "Example Emails" accordion. Each entry is
# a one-element list because gr.Examples is wired to a single input component
# (the email textbox).
EXAMPLE_EMAILS = [
    ["Subject: URGENT - You Have Won $5,000,000!!!\n\nDear Friend,\n\nCONGRATULATIONS!!! You have been selected as the winner of our international lottery program!!!\nTo claim your $5,000,000 USD prize, click the link below IMMEDIATELY and provide your bank details.\n\nACT NOW - This offer expires in 24 hours!!!\n\nClick here: http://totally-legit-prize.com/claim\nSend $500 processing fee to unlock your winnings.\n\nBest regards,\nDr. Prince Mohammed"],
    ["Subject: Team sync Thursday 2pm\n\nHi everyone,\n\nJust a reminder that we have our weekly team sync this Thursday at 2pm in Conference Room B.\n\nAgenda:\n- Sprint review\n- Q2 planning discussion\n- New hire onboarding update\n\nPlease come prepared with your status updates.\n\nThanks,\nSarah"],
    ["Subject: Your account has been compromised!\n\nDear Customer,\n\nWe detected suspicious activity on your account. Click here immediately to verify your identity: http://secure-bank-login.com/verify\n\nIf you do not verify within 24 hours, your account will be permanently locked.\n\nSecurity Team"],
    ["Subject: Thanksgiving dinner plans\n\nHi everyone!\n\nI wanted to start planning for Thanksgiving dinner. I'm thinking we could do it at my place this year. What does everyone think about 4pm?\n\nLet me know if you have any dietary restrictions or if you want to bring a dish.\n\nLove,\nMom"],
    ["Subject: Best prices on V1AGRA and C1ALIS!!!\n\n$$$ SAVE BIG $$$\nBuy now and get 80% OFF!!!\nNo prescription needed! Free shipping!\nOrder at http://cheap-pharma-deals.com\n\nLIMITED TIME OFFER - ACT NOW!"],
]
# ---------------------------------------------------------------------------
# 11. Main orchestration function
# ---------------------------------------------------------------------------
def classify_and_explain(email_text, uploaded_file, threshold):
    """Main function called by Gradio. Returns all outputs for all tabs + feedback state.

    An uploaded file, when given, replaces the pasted text. The return value
    is always a 7-tuple in the order the UI expects: result markdown, LIME
    figure, SHAP figure, ELI5 HTML, comparison markdown, summary markdown,
    and the packed feedback state string. On invalid input or missing models
    the first element carries the message and the rest are empty.

    Each explainer is best-effort: if one fails, its error is printed and its
    outputs become None, and the remaining results are still returned.
    """
    # Figure out if the user pasted text or uploaded a file.
    if uploaded_file is not None:
        try:
            email_text = Path(uploaded_file).read_text(encoding='utf-8')
        except Exception as e:
            # Best-effort by design: report in the UI, but log the cause too.
            print(f"File read error: {e}")
            return ("Could not read file.", None, None, "Error reading file.", "", "", "")

    if email_text is None or email_text.strip() == '':
        return ("Please enter email text or upload a file.", None, None, "", "", "", "")

    if voting_model is None:
        return ("Models not found. Run `python3 train.py` first.", None, None, "", "", "", "")

    # Run the email through the model to get the spam/ham prediction.
    label, confidence, spam_proba, combined = classify_email(email_text, threshold)

    # LIME is guarded the same way as SHAP and ELI5 so that one failing
    # explainer cannot abort the whole request.
    try:
        lime_fig, lime_exp = generate_lime_explanation(combined)
    except Exception as e:
        print(f"LIME error: {e}")
        lime_fig, lime_exp = None, None

    try:
        shap_fig, shap_sv, shap_top_idx = generate_shap_explanation(email_text)
    except Exception as e:
        print(f"SHAP error: {e}")
        shap_fig, shap_sv, shap_top_idx = None, None, None

    try:
        eli5_html, eli5_names = generate_eli5_explanation(combined)
    except Exception as e:
        print(f"ELI5 error: {e}")
        eli5_html, eli5_names = None, None

    result_md = f"## {'SPAM' if 'SPAM' in label else 'HAM (Not Spam)'}\n\n"
    result_md += f"**Confidence:** {confidence:.1%}\n\n"
    result_md += f"**Threshold:** {threshold:.0%}\n\n"
    result_md += f"**Spam probability:** {spam_proba:.1%}\n\n"
    if lime_exp is not None:
        result_md += "**Key factors:**\n"
        for feat_rule, weight in lime_exp.as_list()[:5]:
            direction = "pushes toward spam" if weight > 0 else "pushes toward ham"
            result_md += f"- **{feat_rule}** {direction}\n"

    comparison_md = generate_comparison(lime_exp, shap_sv, shap_top_idx, eli5_names)
    summary_md = generate_plain_summary(label, confidence, spam_proba, lime_exp, shap_sv, shap_top_idx)
    eli5_display = eli5_html or "<p>ELI5 explanation not available.</p>"

    # Pack the prediction and the first 500 characters of the email into one
    # string so the feedback buttons can use it later.
    feedback_state = f"{label}|||{confidence:.4f}|||{threshold:.4f}|||{email_text[:500]}"
    return (result_md, lime_fig, shap_fig, eli5_display, comparison_md, summary_md,
            feedback_state)
# ---------------------------------------------------------------------------
# 12. Feedback handlers
# ---------------------------------------------------------------------------
def handle_feedback(hidden_state, is_correct, user_label=""):
    """Log the user's verdict on the last classification and return a status message.

    Parameters
    ----------
    hidden_state : packed string "label|||confidence|||threshold|||email_text"
        produced by the classification step; empty if nothing was classified
    is_correct : True if the user confirmed the prediction, False if not
    user_label : the label the user says is right (used when is_correct is False)

    Returns
    -------
    A message for the feedback area. If the state is missing or malformed,
    nothing is logged and a "no classification" message is returned.
    """
    no_state_msg = "No classification to give feedback on."

    # Make sure a classification has actually been run before giving feedback.
    if not hidden_state:
        return no_state_msg

    # Split on the first three separators only: the email text is the last
    # field and may itself contain "|||".
    parts = hidden_state.split('|||', 3)
    if len(parts) < 4:
        return no_state_msg

    predicted_label, raw_confidence, raw_threshold, email_text = parts
    try:
        predicted_confidence = float(raw_confidence)
        threshold_used = float(raw_threshold)
    except ValueError:
        # Corrupted state: treat it like a missing classification.
        return no_state_msg

    if is_correct:
        # User confirmed the prediction was right — log it as a correct response.
        corrections = log_feedback(email_text, predicted_label, predicted_confidence,
                                   threshold_used, 'correct')
        return f"Thanks for the feedback! ({corrections} corrections collected so far)"

    # User said the prediction was wrong — log it with their correction.
    corrections = log_feedback(email_text, predicted_label, predicted_confidence,
                               threshold_used, 'wrong', user_label)
    return f"Correction logged! ({corrections} corrections collected so far)"
# ---------------------------------------------------------------------------
# 13. Gradio Blocks UI
# ---------------------------------------------------------------------------
# Markdown shown in the "How It Works" tab. The mis-encoded em dashes in the
# original text are restored so they render correctly.
HOW_IT_WORKS_MD = """
## How This App Works
### What is spam classification?
Spam classification automatically identifies unwanted or malicious emails (spam) vs. legitimate messages (ham). This helps protect users from phishing scams, fraudulent offers, and unwanted advertising.
### The Model
This app uses a **Voting Ensemble** — three different machine learning models that each "vote" on whether an email is spam:
- **Random Forest** — builds many decision trees and takes the majority vote
- **Logistic Regression** — finds a mathematical boundary between spam and ham
- **Support Vector Machine (SVM)** — finds the widest possible margin between classes
By combining all three, the ensemble is more accurate than any single model alone.
### Feature Extraction
The model looks at two types of features:
- **TF-IDF (Term Frequency-Inverse Document Frequency)** — measures how important each word is. Common spam words like "prize" or "click" get high scores.
- **24 Metadata Features** — structural patterns like exclamation mark density, dollar sign count, ALL CAPS ratio, URL count, and more.
### Explainable AI (XAI) Methods
This app doesn't just classify — it explains **why**:
- **LIME** — Removes words one at a time and watches how the prediction changes. Shows which words matter most.
- **SHAP** — Uses game theory to calculate each feature's "fair share" of the prediction. Based on Nobel Prize-winning mathematics.
- **ELI5** — Looks directly at the model's internal weights to show which features it relies on most.
### Feedback & Retraining
When you click "Correct" or "Wrong", your feedback is saved. After enough corrections accumulate, the model can be retrained with the new examples to improve over time. This is called **human-in-the-loop machine learning**.
### Disclaimer
This model was created as a university course project. It is intended for **educational and research purposes only** and should not be used as a sole spam filter in production. Always use established email security tools for real-world spam filtering.
"""
# Visual theme passed to gr.Blocks below: Gradio's Soft theme with custom
# hues and Google-hosted fonts for body and monospace text.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="red",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    font_mono=gr.themes.GoogleFont("IBM Plex Mono"),
)
# Extra CSS passed to gr.Blocks below. The class names match the
# `elem_classes` used in the layout and the classes inside TOPBAR_HTML. The
# mis-encoded box-drawing characters in the section comments are restored.
custom_css = """
/* ── Container ── */
.gradio-container {
max-width: 1600px !important;
margin: 0 auto !important;
padding: 1.5rem 2rem !important;
}
/* ── Top bar ── */
.topbar {
background: linear-gradient(135deg, #f8fafc 0%, #eef2ff 100%);
border: 1px solid #e2e8f0;
border-radius: 14px;
padding: 1.4rem 1.8rem 1.2rem;
margin-bottom: 1.2rem;
box-shadow: 0 1px 3px rgba(0,0,0,0.06);
text-align: center;
}
.topbar-title {
font-size: 22px;
font-weight: 700;
color: #1e293b;
margin: 0 0 0.3rem;
}
.topbar-subtitle {
font-size: 13px;
color: #64748b;
margin: 0 0 0.7rem;
}
.topbar-badges {
display: flex;
justify-content: center;
gap: 0.5rem;
flex-wrap: wrap;
}
.topbar-badge {
display: inline-block;
background: #e0e7ff;
color: #3730a3;
font-size: 11.5px;
font-weight: 600;
padding: 0.25rem 0.7rem;
border-radius: 999px;
letter-spacing: 0.02em;
}
/* ── Input panel (left column) ── */
.input-panel {
background: linear-gradient(180deg, #ffffff 0%, #f8fafc 100%);
border: 1px solid #e2e8f0;
border-radius: 14px;
padding: 1.2rem;
box-shadow: 0 1px 3px rgba(0,0,0,0.04);
}
/* ── Output panel (right column) ── */
.output-panel {
background: linear-gradient(180deg, #ffffff 0%, #f8fafc 100%);
border: 1px solid #e2e8f0;
border-radius: 14px;
padding: 1.2rem;
box-shadow: 0 1px 3px rgba(0,0,0,0.04);
}
.output-panel .plot-container {
max-height: 420px;
overflow-y: auto;
}
.output-panel .prose {
max-height: 420px;
overflow-y: auto;
}
/* ── Feedback card ── */
.feedback-card {
background: linear-gradient(135deg, #f8fafc 0%, #f1f5f9 100%);
border: 1px solid #e2e8f0;
border-radius: 14px;
padding: 1rem 1.4rem;
margin-top: 1rem;
box-shadow: 0 1px 3px rgba(0,0,0,0.04);
}
/* ── Classify button ── */
.classify-btn button {
border-radius: 10px !important;
}
/* ── Responsive ── */
@media (max-width: 980px) {
.gradio-container {
padding: 1rem !important;
}
.topbar {
padding: 1rem 1.2rem;
}
.input-panel, .output-panel {
min-width: 0 !important;
}
}
"""
# Static HTML banner rendered at the top of the page via gr.HTML. Its class
# names are styled by the `.topbar*` rules in custom_css.
# NOTE(review): the "97.4% Accuracy" badge is a hard-coded string, not read
# from the loaded model — confirm it still matches the trained model.
TOPBAR_HTML = """
<div class="topbar">
<div class="topbar-title">Spam Email Classifier with XAI</div>
<div class="topbar-subtitle">
Classify emails as spam or ham and understand <strong>why</strong> using
LIME, SHAP, and ELI5 explainable AI methods
</div>
<div class="topbar-badges">
<span class="topbar-badge">Ensemble Model</span>
<span class="topbar-badge">LIME</span>
<span class="topbar-badge">SHAP</span>
<span class="topbar-badge">ELI5</span>
<span class="topbar-badge">97.4% Accuracy</span>
</div>
</div>
"""
# Page layout and event wiring for the app.
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# this file. In particular, confirm that the feedback card sits below the
# two-column row and not inside the output column.
with gr.Blocks(title="Spam Email Classifier with XAI", theme=theme, css=custom_css) as demo:
    gr.HTML(TOPBAR_HTML)
    # Holds the packed "label|||confidence|||threshold|||email" string that
    # classify_and_explain returns and handle_feedback later unpacks.
    hidden_state = gr.State("")
    with gr.Row(equal_height=False):
        # Left column: inputs.
        with gr.Column(scale=2, min_width=360, elem_classes="input-panel"):
            email_input = gr.Textbox(
                label="Email Text",
                placeholder="Paste your email here...",
                lines=8,
                autoscroll=False,
            )
            file_input = gr.File(
                label="Or upload a .txt file",
                file_types=['.txt'],
            )
            threshold_slider = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.05,
                # Falls back to 0.5 when no threshold was loaded (or it is 0).
                value=optimal_threshold if optimal_threshold else 0.5,
                label="Classification Threshold",
                info="Emails with spam probability above this are classified as spam.",
            )
            classify_btn = gr.Button("Classify", variant="primary", size="lg",
                                     elem_classes="classify-btn")
            with gr.Accordion("Example Emails", open=False):
                gr.Examples(
                    examples=EXAMPLE_EMAILS,
                    inputs=[email_input],
                    label="Click to load an example",
                    cache_examples=False,
                )
        # Right column: one tab per output of classify_and_explain, plus a
        # static explanation tab.
        with gr.Column(scale=3, min_width=480, elem_classes="output-panel"):
            with gr.Tabs():
                with gr.Tab("Result"):
                    result_output = gr.Markdown(label="Classification Result")
                with gr.Tab("LIME"):
                    gr.Markdown("*LIME perturbs the input and fits a local model "
                                "to see which features matter most.*")
                    lime_output = gr.Plot(label="LIME Explanation")
                with gr.Tab("SHAP"):
                    gr.Markdown("*SHAP uses game theory to assign each feature "
                                "a contribution value.*")
                    shap_output = gr.Plot(label="SHAP Explanation")
                with gr.Tab("ELI5"):
                    gr.Markdown("*ELI5 shows feature weights directly from the "
                                "model's internals.*")
                    eli5_output = gr.HTML(label="ELI5 Explanation")
                with gr.Tab("Compare"):
                    compare_output = gr.Markdown(label="Method Comparison")
                with gr.Tab("Summary"):
                    summary_output = gr.Markdown(label="Plain English Summary")
                with gr.Tab("How It Works"):
                    gr.Markdown(HOW_IT_WORKS_MD)
    # Feedback controls; the message component doubles as the status line
    # that handle_feedback overwrites.
    with gr.Group(elem_classes="feedback-card"):
        with gr.Row():
            feedback_msg = gr.Markdown("**Was this classification correct?**")
            correct_btn = gr.Button("Correct", variant="secondary", scale=0,
                                    min_width=100)
            wrong_btn = gr.Button("Wrong", variant="stop", scale=0,
                                  min_width=100)
            correction_dropdown = gr.Dropdown(
                choices=["Spam", "Ham"],
                label="Correct label",
                scale=0,
                min_width=120,
            )
    # The seven outputs are listed in the same order as the tuple returned by
    # classify_and_explain.
    classify_btn.click(
        fn=classify_and_explain,
        inputs=[email_input, file_input, threshold_slider],
        outputs=[result_output, lime_output, shap_output, eli5_output,
                 compare_output, summary_output, hidden_state],
    )
    # A constant gr.State supplies handle_feedback's is_correct argument.
    correct_btn.click(
        fn=handle_feedback,
        inputs=[hidden_state, gr.State(True)],
        outputs=[feedback_msg],
    )
    wrong_btn.click(
        fn=handle_feedback,
        inputs=[hidden_state, gr.State(False), correction_dropdown],
        outputs=[feedback_msg],
    )
# ---------------------------------------------------------------------------
# 14. Launch
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Turn on request queuing, then serve on all interfaces at port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)