# app.py — Spam Email Classifier with Explanations
# This file runs the Gradio web app.
# It loads a trained model, classifies an email as spam or not spam,
# and shows three different explanations of why it made that choice.
# University course project — Explainable AI for spam detection
# Features: LIME, SHAP, ELI5, side-by-side comparison, plain English summary,
# user feedback logging, and batch retrain support.

import csv
import os
from datetime import datetime
from pathlib import Path

import nltk
nltk.download('stopwords', quiet=True)

import eli5
import gradio as gr
import lime
import lime.lime_tabular
import matplotlib
matplotlib.use('Agg')  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import shap
import joblib
from scipy.sparse import hstack, csr_matrix

from utils import (preprocess_text, compute_metadata_features,
                   META_FEATURE_NAMES, FEATURE_DESCRIPTIONS)

# ---------------------------------------------------------------------------
# 1. Model Loading
# ---------------------------------------------------------------------------
models_dir = Path(__file__).parent / 'models'
feedback_dir = Path(__file__).parent / 'feedback'
feedback_dir.mkdir(exist_ok=True)
FEEDBACK_CSV = feedback_dir / 'feedback_log.csv'

try:
    voting_model = joblib.load(models_dir / 'voting_model.joblib')
    tfidf_vectorizer = joblib.load(models_dir / 'tfidf_vectorizer.joblib')
    meta_scaler = joblib.load(models_dir / 'meta_scaler.joblib')
    feature_names = joblib.load(models_dir / 'feature_names.joblib')
    optimal_threshold = joblib.load(models_dir / 'optimal_threshold.joblib')
    training_sample = joblib.load(models_dir / 'training_sample.joblib')

    # Works with both VotingClassifier (full train) and RandomForestClassifier (fast train)
    if hasattr(voting_model, 'named_estimators_'):
        raw_rf = voting_model.named_estimators_['rf']
    else:
        raw_rf = voting_model
    print(f"All models loaded. Threshold = {optimal_threshold:.4f}")
except FileNotFoundError as e:
    # Leave every artefact as None so the module can still be imported;
    # the explanation helpers below check for None before using them.
    print(f"Model file not found: {e}")
    voting_model = None
    tfidf_vectorizer = None
    meta_scaler = None
    feature_names = None
    optimal_threshold = None
    training_sample = None
    raw_rf = None

# ---------------------------------------------------------------------------
# 2. LIME Explainer Setup
# ---------------------------------------------------------------------------
lime_explainer = None
if training_sample is not None and feature_names is not None:
    lime_explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=training_sample,
        feature_names=feature_names,
        class_names=['Ham', 'Spam'],
        mode='classification',
    )
    print("LIME explainer ready.")


# ---------------------------------------------------------------------------
# 3. classify_email
# ---------------------------------------------------------------------------
def classify_email(email_text, threshold):
    """Classify a single email as spam or ham.

    Args:
        email_text: Raw email text. It is cleaned with ``preprocess_text``
            for the TF-IDF features and passed unmodified to
            ``compute_metadata_features``.
        threshold: Spam-probability cut-off; a probability greater than or
            equal to it is labelled spam.

    Returns:
        ``(label, confidence, spam_proba, combined_features)`` where
        ``label`` is ``"SPAM"`` or ``"HAM (Not Spam)"``, ``confidence`` is
        ``spam_proba`` for spam and ``1.0 - spam_proba`` otherwise, and
        ``combined_features`` is the one-row sparse matrix of TF-IDF columns
        followed by the scaled metadata columns.
    """
    cleaned_text = preprocess_text(email_text)
    tfidf_features = tfidf_vectorizer.transform([cleaned_text])

    meta_raw = compute_metadata_features([email_text])
    meta_scaled = meta_scaler.transform(meta_raw)

    # Column order matters: TF-IDF block first, metadata block last.
    combined = hstack([tfidf_features, csr_matrix(meta_scaled)])
    spam_proba = voting_model.predict_proba(combined)[0][1]

    if spam_proba >= threshold:
        label = "SPAM"
        confidence = spam_proba
    else:
        label = "HAM (Not Spam)"
        confidence = 1.0 - spam_proba
    return label, confidence, spam_proba, combined


# ---------------------------------------------------------------------------
# 4. LIME explanation
# ---------------------------------------------------------------------------
def generate_lime_explanation(combined_features):
    """Generate a LIME explanation for one feature row.

    Args:
        combined_features: Sparse one-row feature matrix, as returned by
            ``classify_email``; it is densified before being handed to LIME.

    Returns:
        ``(figure, explanation)``, or ``(None, None)`` when the LIME
        explainer could not be built at import time.
    """
    if lime_explainer is None:
        return None, None

    instance = combined_features.toarray()[0]
    explanation = lime_explainer.explain_instance(
        instance,
        voting_model.predict_proba,
        num_features=10,
    )
    fig = explanation.as_pyplot_figure()
    fig.tight_layout()
    return fig, explanation


# ---------------------------------------------------------------------------
# 5. SHAP explanation (metadata features only — fast)
# ---------------------------------------------------------------------------
# This function runs the model using only the metadata features, not the word features.
# It fills in zeros for the word (TF-IDF) part so the model gets the right input shape.
def predict_with_meta_only(meta_features, num_tfidf, model):
    """Return ``model.predict_proba`` for rows that carry metadata only.

    Args:
        meta_features: 2D array with one row per sample (metadata columns).
        num_tfidf: Number of TF-IDF columns the model expects before the
            metadata columns.
        model: Fitted estimator exposing ``predict_proba``.

    Returns:
        Whatever ``model.predict_proba`` returns for the padded input.
    """
    # meta_features is a 2D array with one row per sample
    n_samples = meta_features.shape[0]
    # Create a block of zeros with the same number of columns as the TF-IDF features
    tfidf_zeros = csr_matrix((n_samples, num_tfidf))
    # Stick the zeros and the metadata together side by side
    combined = hstack([tfidf_zeros, csr_matrix(meta_features)])
    return model.predict_proba(combined)


def _spam_class_shap_values(shap_values, num_meta):
    """Flatten SHAP output to one value per metadata feature for class 1 (spam).

    Older SHAP releases return a list with one array per class; newer ones
    return a single array shaped (n_samples, n_features, n_outputs) for
    multi-output models. Flattening the 3D form directly would interleave
    the two classes, so the spam slice is selected first.
    """
    if isinstance(shap_values, list):
        values = np.array(shap_values[1])
    else:
        values = np.array(shap_values)
        if values.ndim == 3 and values.shape[-1] > 1:
            values = values[..., 1]

    sv = values.flatten()
    # Defensive fallback kept from the original: keep only the trailing
    # metadata-sized slice if more values than expected came back.
    if len(sv) > num_meta:
        sv = sv[-num_meta:]
    return sv


def generate_shap_explanation(email_text):
    """Generate a SHAP bar chart for the metadata features of one email.

    The TF-IDF columns are zero-filled (see ``predict_with_meta_only``) so
    that KernelExplainer only has to perturb the metadata columns. The
    background set is the first 50 rows of ``training_sample``.

    Args:
        email_text: Raw email text.

    Returns:
        ``(figure, shap_values, top_indices)`` where ``shap_values`` is a 1D
        array of spam-class SHAP values and ``top_indices`` holds the
        indices of up to 10 features with the largest absolute value,
        largest first. Returns ``(None, None, None)`` when the training
        sample or the model is unavailable.
    """
    if training_sample is None or voting_model is None:
        return None, None, None

    num_meta = len(META_FEATURE_NAMES)
    background_meta = training_sample[:50, -num_meta:]

    meta_raw = compute_metadata_features([email_text])
    meta_scaled = meta_scaler.transform(meta_raw)
    num_tfidf = training_sample.shape[1] - num_meta

    # Wrap the top-level function so SHAP can call it with just meta_features
    def shap_predict(meta_features):
        return predict_with_meta_only(meta_features, num_tfidf, voting_model)

    explainer = shap.KernelExplainer(shap_predict, background_meta)
    shap_values = explainer.shap_values(meta_scaled, nsamples=100)
    sv = _spam_class_shap_values(shap_values, num_meta)

    # One argsort serves both outputs: ascending order for the horizontal
    # bar chart (largest bar on top), reversed head for the top-10 indices.
    sorted_indices = np.argsort(np.abs(sv))
    top_idx = sorted_indices[::-1][:10]

    sorted_names = [META_FEATURE_NAMES[idx] for idx in sorted_indices.tolist()]
    sorted_values = sv[sorted_indices]

    fig, ax = plt.subplots(figsize=(8, 6))
    colors = ['#d62728' if val > 0 else '#1f77b4' for val in sorted_values]
    ax.barh(sorted_names, sorted_values, color=colors)
    ax.set_xlabel('SHAP Value (impact on spam probability)')
    ax.set_title('SHAP Feature Importance (Metadata Features)')
    ax.axvline(x=0, color='black', linewidth=0.5)
    fig.tight_layout()
    return fig, sv, top_idx


# ---------------------------------------------------------------------------
# 6. ELI5 explanation
# ---------------------------------------------------------------------------
def generate_eli5_explanation(combined_features):
    """Generate ELI5 HTML and top feature names. Returns (html_string, feature_names_list) or (None, None)."""
    if raw_rf is None or feature_names is None:
        return None, None
    instance = combined_features.toarray()[0]
    eli5_exp = eli5.explain_prediction(raw_rf, instance, feature_names=feature_names, top=10)
    raw_html = eli5.format_as_html(eli5_exp)
    # ELI5 outputs a fragment (no tag). Strip the boilerplate block
    # — it contains methodology text that overflows the container and isn't useful
    # for students. Keep the