# Spam Email Classifier with XAI Explanations
# ENGT 375 Project - Spring 2026 - ODU
# Uses LIME, SHAP, ELI5, and Ollama/Qwen 3.5 (2b) for explanations

import os
import streamlit as st
import numpy as np
import pandas as pd
import joblib
import re
import json
import requests
import shutil
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import shap
import lime
import lime.lime_tabular
import eli5
from pathlib import Path

# I use sparse matrices here because TF-IDF creates a huge matrix and
# converting to dense would use way too much memory (found this on Stack Overflow)
from scipy.sparse import hstack, csr_matrix

from utils_student import (preprocess_text, compute_metadata_features,
                           spam_context_phrases, ham_context_phrases,
                           registration_phrases, url_shorteners,
                           legitimate_platforms, OLLAMA_API, LLM_FEATURE_NAMES)

# Project directories
PROJECT_DIR = Path(__file__).parent
MODELS_DIR = PROJECT_DIR / 'models'

# Trusted domain whitelist (app-specific, not shared)
trusted_domains = {'.gov', '.mil', '.edu', 'govdelivery.com', 'granicus.com'}

# I used JavaScript here because Streamlit doesn't save settings between sessions
# by default - streamlit_js_eval lets me store the classification results in the
# browser's localStorage so they persist when the page refreshes
from streamlit_js_eval import streamlit_js_eval

STORAGE_KEY = "spam_xai_state"


def _make_serializable(obj):
    """Recursively convert numpy values into JSON-serializable Python types.

    numpy arrays become lists; numpy bool/integer/floating scalars become
    ``bool``/``int``/``float``; dicts, lists and tuples are walked recursively
    (tuples come back as lists). Any other object is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.bool_):
        # np.bool_ is not a subclass of bool, so json.dumps would reject it
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, dict):
        return {k: _make_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_make_serializable(item) for item in obj]
    return obj


def save_to_local_storage(data: dict):
    """Serialize ``data`` to JSON and store it in browser localStorage under STORAGE_KEY.

    The JSON text is embedded directly as a JS literal. The widget key is
    derived from the payload so each distinct state triggers a fresh JS call.
    """
    serializable = _make_serializable(data)
    payload = json.dumps(serializable)
    js = "localStorage.setItem('%s', JSON.stringify(%s))" % (STORAGE_KEY, payload)
    streamlit_js_eval(js_expressions=js, key="save_%d" % hash(payload))


def load_from_local_storage():
    """Load the saved state dict from browser localStorage.

    Returns:
        The stored dict, or None when nothing is stored, the stored text is
        not valid JSON, or it does not decode to a JSON object (callers use
        ``.get()`` on the result).
    """
    result = streamlit_js_eval(
        js_expressions="localStorage.getItem('%s')" % STORAGE_KEY,
        key="load_state"
    )
    if not result:
        return None
    try:
        data = json.loads(result)
    except (json.JSONDecodeError, ValueError, TypeError):
        return None
    return data if isinstance(data, dict) else None


def clear_local_storage():
    """Remove the saved state entry (STORAGE_KEY) from browser localStorage."""
    streamlit_js_eval(
        js_expressions="localStorage.removeItem('%s')" % STORAGE_KEY,
        key="clear_state"
    )


def save_feedback(email_text, predicted_label, correct_label, spam_prob, notes, feedback_type):
    """Append one feedback record to data/feedback/feedback_log.csv.

    Creates the directory if needed and writes the header row when the file
    is new or empty. The email text is truncated to 2000 characters.

    Returns:
        Number of feedback records in the file (excluding the header).
    """
    import csv
    import datetime

    feedback_dir = PROJECT_DIR / 'data' / 'feedback'
    feedback_dir.mkdir(parents=True, exist_ok=True)
    feedback_file = feedback_dir / 'feedback_log.csv'
    # An existing but empty file still needs the header row
    needs_header = not feedback_file.exists() or feedback_file.stat().st_size == 0

    with open(feedback_file, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(['timestamp', 'email_text', 'predicted_label', 'correct_label',
                             'spam_prob', 'user_notes', 'feedback_type'])
        writer.writerow([
            datetime.datetime.now().isoformat(), email_text[:2000], predicted_label,
            correct_label, '%.4f' % spam_prob, notes, feedback_type
        ])

    # Count CSV records rather than physical lines: email bodies contain newlines,
    # which the csv writer quotes into multi-line records.
    with open(feedback_file, newline='', encoding='utf-8') as f:
        row_count = sum(1 for _ in csv.reader(f)) - 1
    return row_count


# Qwen may wrap its reasoning in <think>...</think> even when asked not to
_THINK_TAG_RE = re.compile(r'<think>.*?</think>', re.DOTALL)
# Order of the LLM score columns (must match the order used at training time)
_LLM_SCORE_KEYS = ('promotional', 'transactional', 'personal',
                   'phishing', 'urgency', 'formality')


def extract_llm_features_single(text, model_name='qwen3.5:2b'):
    """Score one email's intent and tone via Ollama.

    Sends the first 500 characters of ``text`` to the chat API and parses a
    JSON object with the six scores in ``_LLM_SCORE_KEYS``; a missing key
    defaults to 0.5.

    Returns:
        A (1, 6) numpy array of scores. Falls back to all 0.5 (neutral) when
        Ollama is unreachable, returns a non-200 status, or the reply cannot
        be parsed.
    """
    truncated = text[:500]
    prompt = (
        'Rate this email on these dimensions (0.0 to 1.0).\n'
        'Respond with ONLY valid JSON: {"promotional": X, "transactional": X, '
        '"personal": X, "phishing": X, "urgency": X, "formality": X}\n'
        '/no_think\n\n'
        'Email: "%s"' % truncated
    )
    try:
        resp = requests.post(OLLAMA_API, json={
            'model': model_name,
            'messages': [{'role': 'user', 'content': prompt}],
            'stream': False,
            'think': False,
            'options': {'temperature': 0.1, 'num_predict': 100}
        }, timeout=30)
        if resp.status_code == 200:
            content = resp.json().get('message', {}).get('content', '')
            content = _THINK_TAG_RE.sub('', content).strip()
            json_match = re.search(r'\{[^}]+\}', content)
            if json_match:
                data = json.loads(json_match.group())
                return np.array(
                    [float(data.get(key, 0.5)) for key in _LLM_SCORE_KEYS]
                ).reshape(1, -1)
        else:
            print('LLM feature extraction: Ollama returned HTTP %d' % resp.status_code)
    except (requests.RequestException, ValueError, TypeError, AttributeError) as e:
        # Best effort: the classifier still runs with neutral LLM features
        print('LLM feature extraction failed (%s); using neutral scores' % e)
    return np.full((1, 6), 0.5)


def _deserialize_results(data):
    """Convert list fields of a restored results dict back to numpy arrays.

    'proba', 'original_proba', 'sv' and 'meta' become float arrays and
    'top_idx' becomes an int array. Mutates and returns ``data``; None is
    passed through.
    """
    if data is None:
        return None
    for key in ('proba', 'original_proba', 'sv', 'meta'):
        if key in data and isinstance(data[key], list):
            data[key] = np.array(data[key])
    if 'top_idx' in data and isinstance(data['top_idx'], list):
        data['top_idx'] = np.array(data['top_idx'], dtype=int)
    return data


# Example emails for testing
EXAMPLE_EMAILS = {
    "Spam: Nigerian Prince": """Subject: URGENT - You Have Won $5,000,000!!!

Dear Friend,

CONGRATULATIONS!!! You have been selected as the winner of our international lottery program!!!

To claim your $5,000,000 USD prize, click the link below IMMEDIATELY and provide your bank details.

ACT NOW - This offer expires in 24 hours!!!

Click here: http://totally-legit-prize.com/claim

Send $500 processing fee to unlock your winnings.

Best regards,
Dr. Prince Mohammed""",

    "Spam: Viagra Ad": """Subject: Best prices on V1AGRA and C1ALIS!!!

$$$ SAVE BIG $$$

Buy now and get 80% OFF!!!
No prescription needed! Free shipping!

Order at http://cheap-pharma-deals.com

LIMITED TIME OFFER - ACT NOW!

Subscribe to our mailing list for more deals!""",

    "Ham: Meeting Invite": """Subject: Team sync Thursday 2pm

Hi everyone,

Just a reminder that we have our weekly team sync this Thursday at 2pm in Conference Room B.

Agenda:
- Sprint review
- Q2 planning discussion
- New hire onboarding update

Please come prepared with your status updates. If you can't make it, let me know and I'll share the notes.

Thanks,
Sarah""",

    "Ham: Tech Discussion": """Subject: Re: Python 3.12 upgrade

Hey Mike,

I tested the upgrade on our staging environment yesterday. Everything looks good except for one deprecation warning in the logging module. I've already submitted a PR to fix it.

The new pattern matching syntax is really nice for our parser module. Want to pair on refactoring that section next week?

Also, did you see the new asyncio improvements? Could simplify our event loop code significantly.

Cheers,
Dave""",

    "Ham: Family Email": """Subject: Thanksgiving dinner plans

Hi everyone!

Hope you're all doing well. I wanted to start planning for Thanksgiving dinner this year. Mom and Dad said they can host again.

I was thinking we could do a potluck style - I'll bring the turkey and stuffing, and everyone else can sign up for sides and desserts. Can everyone reply with what they'd like to bring?

Also let me know if you have any dietary restrictions I should know about.

Love,
Jenny"""
}


# Check if the sender domain is in our trusted list (.gov, .edu, etc.)
# I added this because government and university emails were getting flagged as spam def check_domain_trust(email_text): # Extract email addresses from From: headers or general text from_match = re.search(r'(?:From|Return-Path|Sender):\s*.*?([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})', email_text, re.IGNORECASE) if not from_match: # Fallback: any email address in the text email_match = re.search(r'[a-zA-Z0-9._%+-]+@([a-zA-Z0-9.-]+\.[a-zA-Z]{2,})', email_text) if email_match: domain = email_match.group(1).lower() else: return {'trusted': False, 'domain': None, 'match': None} else: domain = from_match.group(1).split('@')[-1].lower() for t in trusted_domains: if domain.endswith(t.lstrip('.')): return {'trusted': True, 'domain': domain, 'match': t} return {'trusted': False, 'domain': domain, 'match': None} # Detect email header signals for post-classification adjustment # I learned about these email headers from reading about how Gmail filters work def extract_header_features(text): features = {} # List-Unsubscribe header - legitimate newsletters include this so users # can unsubscribe (spammers usually don't bother adding it) features['has_list_unsubscribe'] = bool(re.search(r'List-Unsubscribe:', text, re.IGNORECASE)) # Sender domain checks from_match = re.search(r'(?:From|Return-Path|Sender):\s*.*?@([a-zA-Z0-9.-]+\.[a-zA-Z]{2,})', text, re.IGNORECASE) sender_domain = from_match.group(1).lower() if from_match else '' features['sender_is_gov'] = sender_domain.endswith('.gov') or sender_domain.endswith('.mil') # Authentication results features['has_spf_pass'] = bool(re.search(r'spf=pass', text, re.IGNORECASE)) features['has_dkim_pass'] = bool(re.search(r'dkim=pass', text, re.IGNORECASE)) return features # Load the trained model and supporting files # I load all the XAI libraries and model files here at startup - caching them with # @st.cache_resource means they stay in memory so every classification is fast @st.cache_resource def load_model(): print('Loading the 
trained model...') model = joblib.load(MODELS_DIR / 'random_forest_spam.joblib') vectorizer = joblib.load(MODELS_DIR / 'tfidf_vectorizer.joblib') feature_names = joblib.load(MODELS_DIR / 'feature_names.joblib') training_sample_path = MODELS_DIR / 'training_sample.joblib' training_sample = joblib.load(training_sample_path) if training_sample_path.exists() else None # I need the raw RF model separately because SHAP and ELI5 can't work # with the calibrated wrapper - they need to see the actual tree structure raw_rf_path = MODELS_DIR / 'random_forest_raw.joblib' raw_rf = joblib.load(raw_rf_path) if raw_rf_path.exists() else model # Load optimal threshold if available threshold_path = MODELS_DIR / 'optimal_threshold.joblib' optimal_threshold = joblib.load(threshold_path) if threshold_path.exists() else 0.60 # Load training config (tracks whether LLM features were used) config_path = MODELS_DIR / 'training_config.joblib' training_config = joblib.load(config_path) if config_path.exists() else {'llm_features_used': False} # Load metadata scaler if available scaler_path = MODELS_DIR / 'meta_scaler.joblib' meta_scaler = joblib.load(scaler_path) if scaler_path.exists() else None print('Done!') return model, vectorizer, feature_names, training_sample, raw_rf, optimal_threshold, training_config, meta_scaler # Cache LIME explainer so it's not recreated on every classification @st.cache_resource def get_lime_explainer(_training_data, feature_names): # LIME needs a sample of training data to understand what "normal" looks like # so it can measure how much each word changes the prediction print('Creating LIME explainer (cached)...') return lime.lime_tabular.LimeTabularExplainer( training_data=_training_data, feature_names=feature_names, class_names=['Ham', 'Spam'], mode='classification' ) # Cache SHAP explainer so it's not recreated on every classification @st.cache_resource def get_shap_explainer(_raw_rf): # I use TreeExplainer because it's designed specifically for tree-based 
models # like Random Forest - it's much faster than the generic KernelExplainer print('Creating SHAP TreeExplainer (cached)...') return shap.TreeExplainer(_raw_rf) def check_ollama(): # Check if Ollama is running - prefer qwen3.5:2b, fall back to qwen3.5, then gemma3 try: resp = requests.get('http://localhost:11434/api/tags', timeout=2) if resp.status_code == 200: models = [m['name'] for m in resp.json().get('models', [])] preferred = [m for m in models if 'qwen3.5:2b' in m] fallback_qwen = [m for m in models if 'qwen3.5' in m] fallback_gemma = [m for m in models if 'gemma3' in m] return preferred or fallback_qwen or fallback_gemma return [] except Exception: return [] def get_llm_explanation(email_text, label, confidence, proba, lime_features, shap_features, eli5_features, model_name): # Get natural language explanation from Ollama Qwen 3.5 print('Running LLM explanation...') truncated = email_text[:500] + ('...' if len(email_text) > 500 else '') spam_or_not = "spam" if label == "SPAM" else "NOT spam" lime_part = ', '.join([f'{name} ({weight:+.3f})' for name, weight in lime_features[:5]]) shap_part = ', '.join([f'{name} ({val:+.3f})' for name, val in shap_features[:5]]) eli5_part = ', '.join([str(name) for name in eli5_features[:5]]) prompt = ( f'FACT: This email has been classified as {label} with {confidence * 100:.1f}% confidence.\n' f'Your job is to explain WHY it is {label}, not to reclassify it.\n' f'Ham probability: {proba[0] * 100:.1f}% | Spam probability: {proba[1] * 100:.1f}%\n\n' f'Email (truncated): "{truncated}"\n\n' f'Top features driving this decision:\n' f'- LIME: {lime_part}\n' f'- SHAP: {shap_part}\n' f'- ELI5: {eli5_part}\n\n' f'You are an email security analyst explaining this to a non-technical user.\n' f'Do NOT contradict the classification. The email IS {spam_or_not}.\n\n' f'In 3-5 sentences, explain why this email is {label}. ' f'Reference specific words or patterns from the email. 
' f'Do not use technical jargon like "TF-IDF" or "SHAP values". /no_think' ) try: resp = requests.post(OLLAMA_API, json={ 'model': model_name, 'messages': [{'role': 'user', 'content': prompt}], 'stream': False, 'think': False, 'options': {'temperature': 0.3, 'num_predict': 300} }, timeout=120) if resp.status_code == 200: content = resp.json().get('message', {}).get('content', '') # Strip any residual thinking tags content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() print('Done!') return content or 'No response generated.' return 'Ollama error: HTTP %d' % resp.status_code except (requests.RequestException, ValueError) as e: return 'Could not connect to Ollama (%s). Make sure it is running (ollama serve).' % e # Ask LLM to classify email and provide second opinion def get_llm_second_opinion(email_text, rf_spam_prob, model_name, domain_info=None): print('Running LLM second opinion...') truncated = email_text[:1000] + ('...' if len(email_text) > 1000 else '') domain_hint = "" if domain_info and domain_info.get('trusted'): domain_hint = f'\nIMPORTANT: This email is from {domain_info["domain"]} ({domain_info["match"]} domain). Government, military, and educational emails are almost always legitimate.\n' prompt = ( f'Classify this email as spam or ham (legitimate). Most emails are legitimate.\n' f'{domain_hint}\n' f'HAM examples: newsletters, order confirmations, bank alerts, church emails, subscription updates, gaming notifications, shipping alerts.\n' f'SPAM examples: prize scams, phishing for passwords, fake invoices, Nigerian prince schemes.\n\n' f'If unsure, choose ham. 
False positives are worse than false negatives.\n\n' f'Email: "{truncated}"\n\n' f'Reply with ONLY this JSON format:\n' f'{{"classification": "ham", "confidence": 0.85, "reason": "why"}}\n' f'or\n' f'{{"classification": "spam", "confidence": 0.95, "reason": "why"}}\n' f'/no_think' ) try: resp = requests.post(OLLAMA_API, json={ 'model': model_name, 'messages': [{'role': 'user', 'content': prompt}], 'stream': False, 'think': False, 'options': {'temperature': 0.1, 'num_predict': 150} }, timeout=60) if resp.status_code == 200: content = resp.json().get('message', {}).get('content', '') content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() # Extract JSON from response - try json.loads first, regex fallback data = None try: data = json.loads(content) except (json.JSONDecodeError, ValueError): json_match = re.search(r'\{.*\}', content, re.DOTALL) if json_match: try: data = json.loads(json_match.group()) except (json.JSONDecodeError, ValueError): pass if data: classification = data.get('classification', '').lower().strip() llm_conf = float(data.get('confidence', 0.5)) reason = data.get('reason', '') llm_spam_prob = llm_conf if classification == 'spam' else (1 - llm_conf) print('Done!') return {'spam_prob': llm_spam_prob, 'classification': classification, 'confidence': llm_conf, 'reason': reason} except Exception: pass return None # Apply theme colors to a matplotlib figure and axes def apply_figure_theme(fig, ax, theme): fig.patch.set_facecolor(theme['ax_facecolor']) ax.set_facecolor(theme['ax_facecolor']) ax.title.set_color(theme['text_color']) ax.xaxis.label.set_color(theme['text_color']) ax.yaxis.label.set_color(theme['text_color']) ax.tick_params(colors=theme['text_color']) for spine in ax.spines.values(): spine.set_edgecolor(theme['grid_color']) # Page config st.set_page_config(page_title="Spam Classifier + XAI", layout="wide", page_icon="@", initial_sidebar_state="expanded") # Restore state from localStorage on page refresh if 'results' not in 
st.session_state: if st.session_state.pop('_skip_restore', False): # Reset was just triggered — clear localStorage now (JS actually executes on this rerun) clear_local_storage() else: stored = load_from_local_storage() if stored: restored_results = _deserialize_results(stored.get('results')) if restored_results: st.session_state['results'] = restored_results if stored.get('email_input'): st.session_state['email_input'] = stored['email_input'] if 'dark_mode' in stored: st.session_state['dark_mode'] = stored['dark_mode'] if 'reset_counter' in stored: st.session_state['reset_counter'] = stored['reset_counter'] if 'threshold' in stored: st.session_state['saved_threshold'] = stored['threshold'] # Dark mode toggle setup if 'dark_mode' not in st.session_state: st.session_state['dark_mode'] = False dark_mode = st.session_state['dark_mode'] THEME = { 'spam_color': '#ff6b6b' if dark_mode else '#c62828', 'ham_color': '#69db7c' if dark_mode else '#2e7d32', 'spam_bg': '#3d1515' if dark_mode else '#ffebee', 'ham_bg': '#153d1a' if dark_mode else '#e8f5e9', 'metric_bg': '#2d2d2d' if dark_mode else '#f0f2f6', 'text_color': '#e0e0e0' if dark_mode else '#1a1a1a', 'ax_facecolor': '#1e1e1e' if dark_mode else '#ffffff', 'axvline_color': '#aaaaaa' if dark_mode else '#333333', 'bar_spam': '#ff6b6b' if dark_mode else '#e74c3c', 'bar_ham': '#69db7c' if dark_mode else '#2ecc71', 'bar_edge': '#666666' if dark_mode else '#333333', 'grid_color': '#444444' if dark_mode else '#cccccc', 'gauge_ham': '#69db7c' if dark_mode else '#4CAF50', 'gauge_spam': '#ff6b6b' if dark_mode else '#f44336', } # Build dark mode CSS string _dark_css = "" if dark_mode: _dark_css = """ /* Full-page dark mode overrides */ .stApp, .stApp > header, [data-testid='stAppViewContainer'], [data-testid='stHeader'], section[data-testid='stSidebar'], section[data-testid='stSidebar'] > div { background-color: #1a1a2e !important; color: #e0e0e0 !important; } section[data-testid='stSidebar'] { background-color: #16213e !important; } 
.stApp h1, .stApp h2, .stApp h3, .stApp h4, .stApp p, .stApp label, .stApp span, .stApp div, [data-testid='stMarkdownContainer'], [data-testid='stMarkdownContainer'] p, [data-testid='stMarkdownContainer'] h1, [data-testid='stMarkdownContainer'] h2, [data-testid='stMarkdownContainer'] h3, [data-testid='stMarkdownContainer'] h4 { color: #e0e0e0 !important; } div[data-testid='stMetric'] label, div[data-testid='stMetric'] [data-testid='stMetricValue'], div[data-testid='stMetric'] [data-testid='stMetricLabel'] { color: #e0e0e0 !important; } .stTabs [data-baseweb='tab'] { color: #b0b0b0 !important; } .stTabs [data-baseweb='tab'][aria-selected='true'] { color: #ff6b6b !important; } .stTextArea textarea, [data-baseweb='textarea'] textarea { background-color: #0f3460 !important; color: #e0e0e0 !important; border-color: #333 !important; } .stButton button:not([data-testid='stBaseButton-primary']) { background-color: #16213e !important; color: #e0e0e0 !important; border-color: #444 !important; } [data-testid='stExpander'] { background-color: #16213e !important; border-color: #333 !important; } hr { border-color: #333 !important; } [data-testid='stSlider'] p { color: #b0b0b0 !important; } pre, code, .stCodeBlock { color: #e0e0e0 !important; background-color: #1e1e2e !important; } .stJson { color: #e0e0e0 !important; background-color: #1e1e2e !important; } [data-testid='stExpander'] pre { background-color: #1e1e2e !important; } div[data-baseweb='popover'] > div, div[data-baseweb='tooltip'] > div { background-color: #16213e !important; color: #e0e0e0 !important; } """ st.markdown(""" """, unsafe_allow_html=True) # Load model try: model, vectorizer, feature_names, training_sample, raw_rf, saved_threshold, training_config, meta_scaler = load_model() except FileNotFoundError: st.error("Model files not found. 
Please run the notebook or `python retrain.py` first.") st.stop() # Sidebar with st.sidebar: st.toggle("Dark Mode", key='dark_mode') if st.button("New Classification", use_container_width=True): old_counter = st.session_state.get('reset_counter', 0) st.session_state.pop('results', None) st.session_state.pop('email_input', None) st.session_state.pop('feedback_given', None) st.session_state.pop('feedback_wrong', None) st.session_state.pop('feedback_saved', None) st.session_state.pop('feedback_msg', None) st.session_state.pop('email_textarea_%d' % old_counter, None) st.session_state['reset_counter'] = old_counter + 1 st.session_state['_skip_restore'] = True st.rerun() st.markdown("---") st.header("Spam Classifier + XAI") st.caption("ENGT 375 | Spring 2026 | ODU") st.markdown("---") # Classification threshold slider # The threshold slider lets users adjust how cautious the classifier is - # higher threshold = fewer false spam flags but might miss some real spam default_threshold = st.session_state.pop('saved_threshold', None) or 0.60 threshold = st.slider("Classification Threshold", 0.0, 1.0, float(default_threshold), 0.05, help="Emails with spam probability above this threshold are classified as spam. Default 0.60 to reduce false positives.") st.markdown("---") # Input mode selection input_mode = st.radio("Input Mode", ["Email Body Only", "Full Email with Headers"], help="Use 'Full Email with Headers' to get additional signals from email headers (List-Unsubscribe, SPF, DKIM).") st.markdown("---") # Example emails st.subheader("Example Emails") st.caption("Click to load a sample email") for name, text in EXAMPLE_EMAILS.items(): if st.button(name, key="ex_%s" % name, use_container_width=True): st.session_state['email_input'] = text st.markdown("---") # How it works with st.expander("How It Works"): st.markdown(""" **1. Text Input** - Paste email text directly, or upload an email screenshot (the app reads text from images) **2. 
Text Cleanup** - Strips out HTML code, web links, and email addresses - Converts everything to lowercase and reduces words to their root form (e.g., "running" -> "run") - Removes common filler words like "the", "is", "and" **3. Feature Extraction** - **Word Importance** (3000 words/phrases) — measures how important each word or phrase is compared to all emails the model has seen - **Email Patterns** (24 measurements) — things like how many exclamation marks, ALL CAPS words, links, dollar signs, and spam-like phrases appear - **AI Analysis** (6 scores, if available) — a language model rates the email's intent (promotional, personal, phishing) and tone (urgency, formality) **4. Decision Forest Classifier** - Hundreds of decision trees that each vote on spam vs. ham — the majority wins - Trained on ~70,000 emails from multiple sources - The model automatically adjusts for uneven spam/ham ratios and fine-tunes its confidence scores - Best settings found by testing many combinations automatically (GridSearchCV) **5. Domain Trust** - Emails from .gov, .mil, and .edu domains are given the benefit of the doubt - Trusted senders are capped at 30% spam probability max **6. 
Explainability (XAI)** - **LIME** — hides different words and watches how the prediction changes to find which words matter most - **SHAP** — calculates how much each word or feature pushed the result toward spam or toward ham, like a tug-of-war score - **ELI5** — looks inside the model to show which words it considers most important - **AI Explanation** — a language model summarizes the findings in everyday language """) # Ollama status qwen_models = check_ollama() if qwen_models: selected_model = st.selectbox("LLM Model", qwen_models, index=0) else: st.warning("Ollama not available - AI Explanation tab disabled") selected_model = None # Main area st.title("Spam Email Classifier with Explainable AI") st.markdown("Classify emails and understand **why** using LIME, SHAP, ELI5, and AI-powered explanations.") # Input text area default_text = st.session_state.get('email_input', '') counter = st.session_state.get('reset_counter', 0) email_text = st.text_area("Paste an email to classify:", value=default_text, height=200, placeholder="Paste email content here...", key="email_textarea_%d" % counter) classify_clicked = st.button("Classify", type="primary", use_container_width=True) if classify_clicked and email_text.strip(): # Compute and cache all results with st.spinner("Classifying and computing XAI explanations..."): print('Running classification...') # Check domain trust BEFORE preprocessing domain_info = check_domain_trust(email_text) # Extract header features if full email mode header_features = {} if input_mode == "Full Email with Headers": header_features = extract_header_features(email_text) clean = preprocess_text(email_text) tfidf_features = vectorizer.transform([clean]) meta_features = compute_metadata_features([email_text]) # Apply metadata scaler if available (matches training normalization) if meta_scaler is not None: meta_features = meta_scaler.transform(meta_features) # Add LLM features if model was trained with them if 
training_config.get('llm_features_used'): if selected_model: llm_feats = extract_llm_features_single(email_text, model_name=selected_model) else: llm_feats = np.full((1, 6), 0.5) # neutral fallback if LLM is offline # Combine TF-IDF + metadata + LLM features into one feature vector # (same order as during training so the model gets the right inputs) X_input = hstack([tfidf_features, csr_matrix(meta_features), csr_matrix(llm_feats)]).toarray() else: X_input = hstack([tfidf_features, csr_matrix(meta_features)]).toarray() # Prediction proba = model.predict_proba(X_input)[0] # LLM scoring removed — small model confidence was static (~92-95%) # LLM still used for AI Explanation tab via get_llm_explanation() llm_second_opinion = None llm_weight_used = 0.0 original_proba = proba.copy() # Domain trust cap - max 30% spam for trusted domains if domain_info.get('trusted'): if proba[1] > 0.30: proba = np.array([0.70, 0.30]) # Header-based post-classification adjustment if header_features: adjustment = 0.0 if header_features.get('has_list_unsubscribe'): adjustment = adjustment - 0.10 # Strong ham signal if header_features.get('sender_is_gov'): adjustment = adjustment - 0.15 # Government sender if header_features.get('has_spf_pass'): adjustment = adjustment - 0.03 if header_features.get('has_dkim_pass'): adjustment = adjustment - 0.03 if adjustment != 0.0: adjusted_spam = max(0.01, min(0.99, proba[1] + adjustment)) proba = np.array([1 - adjusted_spam, adjusted_spam]) # LIME explanation print('Running LIME explanation...') lime_training_data = training_sample if training_sample is not None else X_input explainer = get_lime_explainer(lime_training_data, feature_names) exp = explainer.explain_instance(X_input[0], raw_rf.predict_proba, num_features=10) label_to_explain = list(exp.as_map().keys())[0] feature_weights = exp.as_list(label=label_to_explain) print('Done!') # SHAP explanation - I have to use the raw RF here because SHAP's # TreeExplainer doesn't work with the calibrated 
model wrapper print('Running SHAP explanation...') shap_explainer = get_shap_explainer(raw_rf) shap_values = shap_explainer.shap_values(X_input, check_additivity=False) # SHAP returns different shapes depending on the version, so I have to # check which format it gives me (this took a while to debug) if isinstance(shap_values, list): sv = shap_values[1][0] elif shap_values.ndim == 3: sv = shap_values[0, :, 1] else: sv = shap_values[0] top_idx = np.argsort(np.abs(sv))[::-1][:10] print('Done!') # ELI5 explanation print('Running ELI5 explanation...') eli5_exp = eli5.explain_prediction(raw_rf, X_input[0], feature_names=feature_names, top=10) eli5_html = eli5.format_as_html(eli5_exp) eli5_text = eli5.format_as_text( eli5.explain_prediction(raw_rf, X_input[0], feature_names=feature_names, top=5) ) eli5_top_exp = eli5.explain_prediction(raw_rf, X_input[0], feature_names=feature_names, top=5) eli5_feat_names = [] if hasattr(eli5_top_exp, 'targets') and eli5_top_exp.targets: for fw in eli5_top_exp.targets[0].feature_weights.pos[:5]: eli5_feat_names.append(fw.feature) for fw in eli5_top_exp.targets[0].feature_weights.neg[:5]: eli5_feat_names.append(fw.feature) print('Done!') # AI Explanation (if Ollama available) # I send the XAI results to the LLM so it can explain them in plain English - # most people won't understand raw SHAP values or LIME weights ai_explanation = None if selected_model: lime_feats = [(f, w) for f, w in feature_weights[:5]] shap_feats = [(feature_names[i], float(sv[i])) for i in top_idx[:5]] prediction = 1 if proba[1] >= threshold else 0 label = "SPAM" if prediction == 1 else "HAM" confidence = proba[prediction] ai_explanation = get_llm_explanation( email_text, label, confidence, proba, lime_feats, shap_feats, eli5_feat_names, selected_model ) # Store everything in session state st.session_state['results'] = { 'email_text': email_text, 'clean': clean, 'meta': compute_metadata_features([email_text])[0], 'proba': proba, 'original_proba': original_proba, 
'llm_second_opinion': llm_second_opinion, 'domain_info': domain_info, 'header_features': header_features, 'feature_weights': feature_weights, 'sv': sv, 'top_idx': top_idx, 'eli5_html': eli5_html, 'eli5_text': eli5_text, 'eli5_feat_names': eli5_feat_names, 'ai_explanation': ai_explanation, 'llm_weight_used': llm_weight_used, } # Persist to browser localStorage save_to_local_storage({ 'results': st.session_state['results'], 'email_input': email_text, 'dark_mode': st.session_state.get('dark_mode', False), 'reset_counter': st.session_state.get('reset_counter', 0), 'threshold': threshold, }) elif classify_clicked: st.warning("Please paste an email to classify.") # Display results from session state (survives reruns) if 'results' in st.session_state: r = st.session_state['results'] proba = r['proba'] sv = r['sv'] top_idx = r['top_idx'] feature_weights = r['feature_weights'] # Re-derive label from current threshold (so threshold slider changes take effect) prediction = 1 if proba[1] >= threshold else 0 label = "SPAM" if prediction == 1 else "HAM" confidence = proba[prediction] # Domain trust badge domain_info = r.get('domain_info', {}) if domain_info.get('trusted'): st.markdown( 'Trusted Domain: %s (%s)' % (domain_info["domain"], domain_info["match"]), unsafe_allow_html=True ) # Header features info header_features = r.get('header_features', {}) if header_features: header_signals = [] if header_features.get('has_list_unsubscribe'): header_signals.append("Has unsubscribe link (legitimate mailing lists include this)") if header_features.get('sender_is_gov'): header_signals.append("Government sender domain") if header_features.get('has_spf_pass'): header_signals.append("SPF passed (sender's server is authorized)") if header_features.get('has_dkim_pass'): header_signals.append("DKIM passed (email signature verified)") if header_signals: st.info("**Header signals:** %s" % ' | '.join(header_signals)) # Preprocessing visualization with st.expander("Preprocessing Steps", 
expanded=False): col_a, col_b = st.columns(2) with col_a: st.markdown("**Original Text** (first 300 chars)") st.code(r['email_text'][:300] + ('...' if len(r['email_text']) > 300 else ''), language='text') with col_b: st.markdown("**After Preprocessing** (first 300 chars)") st.code(r['clean'][:300] + ('...' if len(r['clean']) > 300 else ''), language='text') meta_names = ['Exclamation Mark Density', 'Dollar Sign Count', 'ALL CAPS Word Ratio', 'Spam Phrase Count', 'Ham Phrase Count', 'Spam vs. Ham Word Balance', 'Link Count', 'HTML Code Tags', 'Email Length (chars)', 'Avg Sentence Length', 'ALL CAPS Usage', 'Mentions a Specific Date', 'Mentions a Specific Time', 'Date Reference Count', 'Has Unsubscribe Link', 'Has Physical Address', 'Has Proper Greeting', 'Has Contact Info', 'Sign-up/Register Language Score', 'Sales Language vs. Info Ratio', 'Shortened Link Usage', 'Known Platform Links', 'Gov/Edu Link Count', 'Question Mark Count'] st.markdown("**Extracted Metadata Features:**") meta_df = pd.DataFrame({'Feature': meta_names, 'Value': r['meta']}) st.dataframe(meta_df, use_container_width=True, hide_index=True) # Display result st.markdown("---") col1, col2, col3 = st.columns([1.5, 1, 1.5]) with col1: css_class = "spam-result" if prediction == 1 else "ham-result" st.markdown('
%s
' % (css_class, label), unsafe_allow_html=True) with col2: st.metric("Confidence", "%.1f%%" % (confidence * 100)) st.metric("Threshold", "%.0f%%" % (threshold * 100)) with col3: # Confidence gauge using a horizontal bar fig_gauge, ax_gauge = plt.subplots(figsize=(4, 1.8)) ax_gauge.barh([0], [proba[0]], color=THEME['gauge_ham'], height=0.5, label='Ham') ax_gauge.barh([0], [proba[1]], left=[proba[0]], color=THEME['gauge_spam'], height=0.5, label='Spam') ax_gauge.axvline(x=threshold, color=THEME['axvline_color'], linestyle='--', linewidth=2, label='Threshold (%.0f%%)' % (threshold * 100)) ax_gauge.set_xlim(0, 1) ax_gauge.set_yticks([]) ax_gauge.legend(loc='upper center', ncol=3, fontsize=6, bbox_to_anchor=(0.5, 1.4), facecolor=THEME['ax_facecolor'], labelcolor=THEME['text_color']) apply_figure_theme(fig_gauge, ax_gauge, THEME) fig_gauge.subplots_adjust(top=0.65, bottom=0.15) st.pyplot(fig_gauge, transparent=not dark_mode) plt.close() # XAI Explanations st.markdown("---") st.subheader("Explainable AI Analysis") tab_names = ["LIME", "SHAP", "ELI5", "Comparison"] if selected_model: tab_names.append("AI Explanation") tabs = st.tabs(tab_names) # LIME Tab with tabs[0]: st.markdown("#### LIME — What Words Mattered?") st.caption("LIME hides different words in the email and watches how the prediction changes — this reveals which words matter most.") features_sorted = sorted(feature_weights, key=lambda x: x[1]) names_lime = [f[0] for f in features_sorted] weights_lime = [f[1] for f in features_sorted] fig, ax = plt.subplots(figsize=(8, 5)) colors_bar = [THEME['bar_spam'] if w > 0 else THEME['bar_ham'] for w in weights_lime] ax.barh(names_lime, weights_lime, color=colors_bar, edgecolor=THEME['bar_edge'], alpha=0.85) ax.axvline(x=0, color=THEME['axvline_color'], linewidth=0.8) ax.set_title('LIME: Feature Contributions to Classification', fontsize=13, fontweight='bold') ax.set_xlabel('Bars pointing right \u2192 pushes toward SPAM | Bars pointing left \u2192 pushes toward HAM') 
apply_figure_theme(fig, ax, THEME) plt.tight_layout() st.pyplot(fig, transparent=not dark_mode) plt.close() # SHAP Tab with tabs[1]: st.markdown("#### SHAP — Tug-of-War Scores") st.caption("SHAP calculates how much each word or feature pushed the result toward spam or toward ham — like a tug-of-war score.") top_features_shap = [feature_names[i] for i in top_idx] top_vals_shap = sv[top_idx] fig, ax = plt.subplots(figsize=(8, 5)) colors_bar = [THEME['bar_spam'] if v > 0 else THEME['bar_ham'] for v in top_vals_shap] ax.barh(top_features_shap[::-1], top_vals_shap[::-1], color=colors_bar[::-1], edgecolor=THEME['bar_edge'], alpha=0.85) ax.axvline(x=0, color=THEME['axvline_color'], linewidth=0.8) ax.set_title('SHAP: Top Feature Contributions', fontsize=13, fontweight='bold') ax.set_xlabel('Bars pointing right \u2192 pushes toward SPAM | Bars pointing left \u2192 pushes toward HAM') apply_figure_theme(fig, ax, THEME) plt.tight_layout() st.pyplot(fig, transparent=not dark_mode) plt.close() # ELI5 Tab with tabs[2]: st.markdown("#### ELI5 — Model's Own Rankings") st.caption("ELI5 looks inside the model to show which words it considers most important for its decision.") eli5_html = r['eli5_html'] if dark_mode: eli5_html = '
%s
' % eli5_html st.components.v1.html(eli5_html, height=400, scrolling=True) if dark_mode: st.caption("ELI5 uses its own styling - shown with light background for readability.") # Comparison Tab with tabs[3]: st.markdown("#### Side-by-Side Comparison") st.caption("If multiple tools agree a word is important, that's a stronger signal it actually matters.") col1, col2, col3 = st.columns(3) with col1: st.markdown("##### LIME Top 5") for feat, w in feature_weights[:5]: direction = "spam" if w > 0 else "ham" color = THEME['spam_color'] if w > 0 else THEME['ham_color'] display_feat = feat[:20] + "..." if len(feat) > 20 else feat st.markdown("- `%s` -> %s (%+.3f)" % (color, display_feat, direction, w), unsafe_allow_html=True) with col2: st.markdown("##### SHAP Top 5") for i in top_idx[:5]: direction = "spam" if sv[i] > 0 else "ham" color = THEME['spam_color'] if sv[i] > 0 else THEME['ham_color'] display_feat = feature_names[i][:20] + "..." if len(feature_names[i]) > 20 else feature_names[i] st.markdown("- `%s` -> %s (%+.3f)" % (color, display_feat, direction, sv[i]), unsafe_allow_html=True) with col3: st.markdown("##### ELI5 Top 5") for feat_name in r['eli5_feat_names'][:5]: display_feat = feat_name[:20] + "..." if len(feat_name) > 20 else feat_name st.markdown("- `%s`" % display_feat) # Feature agreement analysis st.markdown("---") st.markdown("##### Feature Agreement") lime_top = set(f[0] for f in feature_weights[:10]) shap_top = set(feature_names[i] for i in top_idx[:10]) overlap = lime_top & shap_top st.markdown("**LIME-SHAP overlap** (top 10): **%d** shared features" % len(overlap)) if overlap: st.markdown("Shared: %s" % ', '.join('`%s`' % f for f in sorted(overlap))) # AI Explanation Tab if selected_model: with tabs[4]: st.markdown("#### AI-Powered Explanation") st.caption("Using **%s** via Ollama to explain the classification in plain English." 
% selected_model) if r['ai_explanation']: st.info(r['ai_explanation']) else: st.warning("AI explanation was not generated for this classification.") lime_feats = [(f, w) for f, w in feature_weights[:5]] shap_feats = [(feature_names[i], float(sv[i])) for i in top_idx[:5]] with st.expander("XAI Data Sent to LLM"): st.json({ 'prediction': label, 'confidence': "%.1f%%" % (confidence * 100), 'ham_prob': "%.1f%%" % (proba[0] * 100), 'spam_prob': "%.1f%%" % (proba[1] * 100), 'lime_top5': [{'feature': f, 'weight': round(w, 4)} for f, w in lime_feats], 'shap_top5': [{'feature': f, 'value': round(v, 4)} for f, v in shap_feats], 'eli5_top5': r['eli5_feat_names'][:5] }) # Feedback section # The feedback system lets users correct wrong predictions, and those corrections # get saved so the model can learn from its mistakes when retrained st.markdown("---") st.markdown("### Was this classification correct?") if not st.session_state.get('feedback_given'): fb_col1, fb_col2 = st.columns(2) with fb_col1: if st.button("Yes, correct!", use_container_width=True, type="primary"): prediction = 1 if r['proba'][1] >= threshold else 0 predicted_label = "spam" if prediction == 1 else "ham" count = save_feedback(r['email_text'], predicted_label, predicted_label, r['proba'][1], '', 'correct') st.session_state['feedback_given'] = True st.session_state['feedback_msg'] = ( "Feedback saved! Classified as %s confirmed correct. %d entries logged so far." 
% (predicted_label.upper(), count)) st.rerun() with fb_col2: if st.button("No, it's wrong!", use_container_width=True): st.session_state['feedback_wrong'] = True st.session_state['feedback_given'] = True st.rerun() if st.session_state.get('feedback_wrong') and not st.session_state.get('feedback_saved'): prediction = 1 if r['proba'][1] >= threshold else 0 predicted_label = "spam" if prediction == 1 else "ham" correct_label = "ham" if predicted_label == "spam" else "spam" st.warning("Predicted: **%s** | Correct: **%s**" % (predicted_label.upper(), correct_label.upper())) user_notes = st.text_input("Optional notes (e.g., 'Steam notification'):", key="feedback_notes") if st.button("Save Correction", type="primary"): count = save_feedback(r['email_text'], predicted_label, correct_label, r['proba'][1], user_notes, 'incorrect') st.session_state['feedback_saved'] = True st.session_state['feedback_msg'] = ( "Feedback saved! Predicted %s → Correct %s. %d entries logged so far." % (predicted_label.upper(), correct_label.upper(), count)) st.rerun() if st.session_state.get('feedback_msg'): st.success(st.session_state['feedback_msg']) # Footer st.markdown("---") st.caption("ENGT 375 Project - Spam Classification with XAI | Random Forest + LIME + SHAP + ELI5 + Ollama LLM")