spam-xai-model-v2 / utils.py
VoltageVagabond's picture
Upload folder using huggingface_hub
6673246 verified
raw
history blame
18.5 kB
# Shared utilities for the spam classifier project
# ENGT 375 Project - Spring 2026 - ODU
# I put the shared functions here so I don't have to copy-paste them
# between the training script and the Gradio app
import re
import numpy as np
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
# I came up with these phrase lists by looking at common spam patterns.
# compute_metadata_features() lowercases each email and checks whether each
# phrase appears anywhere in it (plain substring test); every phrase that is
# present adds 1 to the matching count, no matter how often it repeats.

# Phrases that are strong signals for spam
spam_context_phrases = [
    # urgency / prizes
    'act now',
    'limited time',
    'click to claim',
    'you have won',
    # money
    'wire transfer',
    'bank account',
    'million dollar',
    'free gift',
    # classic spam products
    'no prescription',
    'buy now',
    'make money fast',
    'lose weight',
    'casino',
    'free credit',
]

# Phrases that are strong signals for legitimate (ham) email
ham_context_phrases = [
    # newsletters and opt-out footers
    'click to unsubscribe',
    'unsubscribe from',
    'to opt out',
    'this newsletter',
    'you are receiving this',
    # official senders
    'official notice',
    'department of',
    'office of',
    # subscription management / site boilerplate
    'subscribe to updates',
    'manage your subscription',
    'privacy policy',
    'government website',
    # events (the first four also appear in registration_phrases,
    # so they feed both counts)
    'register now',
    'sign up',
    'reserve your spot',
    'rsvp',
    'event details',
    'schedule',
    'agenda',
    'venue',
    # recurring / community events
    'annual',
    'edition',
    'season',
    'community',
]

# Registration/event language (ham signal for community emails)
registration_phrases = [
    'register now',
    'sign up',
    'reserve your spot',
    'rsvp',
    'registration open',
    'tickets available',
    'limited capacity',
    'early bird',
    'reserve your seat',
    'join us',
]

# URL shortener domains (spam signal). compute_metadata_features() compares
# them against the host part of each http(s) URL in the email.
url_shorteners = [
    'bit.ly',
    'tinyurl.com',
    'goo.gl',
    't.co',
    'ow.ly',
    'is.gd',
    'buff.ly',
    'adf.ly',
    'shorte.st',
]

# Legitimate platform domains (ham signal), also compared against URL hosts
legitimate_platforms = [
    # events, mailing lists, surveys
    'eventbrite.com',
    'meetup.com',
    'mailchimp.com',
    'constantcontact.com',
    'surveymonkey.com',
    # big tech, social, developer and gaming sites
    'google.com',
    'zoom.us',
    'microsoft.com',
    'linkedin.com',
    'facebook.com',
    'github.com',
    'youtube.com',
    'steampowered.com',
    'store.steampowered.com',
    # payments and shopping
    'paypal.com',
    'stripe.com',
    'shopify.com',
    'etsy.com',
    'bestbuy.com',
    'target.com',
    'amazon.com',
    # streaming
    'netflix.com',
    'spotify.com',
]
# Names of the 24 metadata features, listed in exactly the column order that
# compute_metadata_features() produces (entry i here names column i there).
META_FEATURE_NAMES = [
    # 1-3: punctuation, money symbols, shouting
    'exclamation_density',
    'dollar_sign_count',
    'caps_word_ratio',
    # 4-6: spam vs ham phrase matches
    'spam_phrase_count',
    'ham_phrase_count',
    'net_spam_context',
    # 7-11: links, markup, size, capitalization
    'url_count',
    'html_tag_count',
    'email_length',
    'avg_sentence_length',
    'capitalization_ratio',
    # 12-14: dates and times
    'has_specific_date',
    'has_specific_time',
    'date_reference_count',
    # 15-19: unsubscribe, address, greeting, contact details, registration
    'has_unsubscribe',
    'has_physical_address',
    'has_proper_greeting',
    'has_contact_info',
    'registration_language_score',
    # 20-24: call-to-action balance, link hosts, questions
    'cta_to_info_ratio',
    'shortener_url_ratio',
    'legitimate_platform_count',
    'gov_edu_url_count',
    'question_mark_count',
]

# Plain-English description of every feature above, keyed by feature name
# and in the same order (shown in the XAI explanation tabs).
FEATURE_DESCRIPTIONS = {
    # 1-3
    'exclamation_density': 'Number of exclamation marks per sentence',
    'dollar_sign_count': 'Number of dollar signs in the email',
    'caps_word_ratio': 'Fraction of words that are ALL CAPS (2+ letters)',
    # 4-6
    'spam_phrase_count': 'Count of known spam phrases found in the email',
    'ham_phrase_count': 'Count of known ham (legitimate) phrases found in the email',
    'net_spam_context': 'Spam phrase count minus ham phrase count',
    # 7-11
    'url_count': 'Number of URLs in the email',
    'html_tag_count': 'Number of HTML tags in the email',
    'email_length': 'Total character length of the email',
    'avg_sentence_length': 'Average number of characters per sentence',
    'capitalization_ratio': 'Fraction of alphabetic characters that are uppercase',
    # 12-14
    'has_specific_date': 'Whether the email mentions a specific date (1 or 0)',
    'has_specific_time': 'Whether the email mentions a specific time with AM/PM (1 or 0)',
    'date_reference_count': 'Number of month name references in the email',
    # 15-19
    'has_unsubscribe': 'Whether the email contains unsubscribe or opt-out language (1 or 0)',
    'has_physical_address': 'Whether the email contains a street address (1 or 0)',
    'has_proper_greeting': 'Whether the email starts with a proper greeting like Dear/Hello (1 or 0)',
    'has_contact_info': 'Whether the email contains a phone number or email address (1 or 0)',
    'registration_language_score': 'Count of event registration phrases found',
    # 20-24
    'cta_to_info_ratio': 'Ratio of call-to-action words to informational words',
    'shortener_url_ratio': 'Fraction of URLs that use URL shortener services',
    'legitimate_platform_count': 'Number of URLs from known legitimate platforms',
    'gov_edu_url_count': 'Number of URLs ending in .gov or .edu',
    'question_mark_count': 'Number of question marks in the email',
}
# Shared NLP resources, built once at import time and reused by preprocess_text().
# NOTE(review): stopwords.words('english') needs the NLTK 'stopwords' corpus to be
# installed (e.g. nltk.download('stopwords')); without it this line raises when the
# module is imported -- confirm the training/app environments set this up.
stop_words = set(stopwords.words('english'))
stemmer = PorterStemmer()
# Clean and stem the input text (Prof. Kuzlu showed us stemming in class)
def preprocess_text(text):
    """Turn one raw email into a space-separated string of stemmed words.

    Steps, in order:
      1. replace HTML tags, URLs and email addresses with spaces;
      2. replace every character that is not an ASCII letter or whitespace
         with a space;
      3. lowercase the result;
      4. keep only words longer than 2 characters that are not English
         stopwords, and Porter-stem each of them.
    """
    # Order matters: tags/URLs/emails must go before punctuation is stripped,
    # otherwise their pieces would survive as ordinary words.
    cleanup_patterns = (
        r'<[^>]+>',                 # HTML tags
        r'https?://\S+|www\.\S+',   # URLs
        r'\S+@\S+',                 # email addresses
        r'[^a-zA-Z\s]',             # anything that is not a letter or whitespace
    )
    for pattern in cleanup_patterns:
        text = re.sub(pattern, ' ', text)

    kept_stems = (
        stemmer.stem(word)
        for word in text.lower().split()
        if len(word) > 2 and word not in stop_words
    )
    return ' '.join(kept_stems)
# Compute 24 metadata features from a list of email texts
# I designed these features based on what I noticed about spam vs ham emails -
# things like exclamation marks, dollar signs, and ALL CAPS words


def _count_phrases(text_lower, phrases):
    """Return how many entries of `phrases` occur as substrings of `text_lower`.

    Each phrase counts at most once, however many times it appears.
    """
    count = 0
    for phrase in phrases:
        if phrase in text_lower:
            count = count + 1
    return count


def _url_host(netloc):
    """Reduce the text captured right after 'http(s)://' to a bare lowercase host.

    The capture stops only at whitespace or '/', so it can still hold a
    query/fragment ("bit.ly?x=1"), user info ("user@host"), a port
    ("host:8080") or trailing punctuation from the sentence ("bit.ly).").
    All of that is removed here.
    """
    host = re.split(r'[?#]', netloc, maxsplit=1)[0]
    host = host.rsplit('@', 1)[-1]
    host = host.split(':', 1)[0]
    return host.lower().rstrip('.,;!)]>"\'')


def _host_matches(host, domain):
    """Return True when `host` is `domain` itself or one of its subdomains.

    Matching on the dot boundary avoids the false positives of a plain
    `domain in host` test, e.g. 't.co' inside 'microsoft.com', or
    'paypal.com' inside the look-alike host 'paypal.com.account-verify.xyz'.
    """
    return host == domain or host.endswith('.' + domain)


def compute_metadata_features(texts):
    """Compute the 24 hand-crafted metadata features for each email.

    Parameters
    ----------
    texts : iterable of str
        Raw email texts (HTML, URLs and punctuation still present).

    Returns
    -------
    numpy.ndarray
        Float array of shape (number of texts, 24); column j is the feature
        named META_FEATURE_NAMES[j]. An empty input gives shape (0, 24)
        rather than a shapeless (0,) array, so it still stacks with other
        feature matrices.

    NOTE: has_specific_time now recognizes "a.m."/"p.m.", and the URL-host
    features (21-23) match whole domains instead of substrings. Models trained
    on the old feature values should be retrained so training and inference
    use the same definitions.
    """
    # Lists that do not depend on the email, built once instead of per text
    date_patterns = [
        r'\b(?:Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday)\b',
        r'\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2}',
        r'\b\d{1,2}/\d{1,2}/\d{2,4}\b',
        r'\b\d{4}-\d{2}-\d{2}\b'
    ]
    months = ['january', 'february', 'march', 'april', 'may', 'june',
              'july', 'august', 'september', 'october', 'november', 'december']
    cta_words = ['buy', 'order', 'click', 'act', 'hurry', 'rush', 'grab', 'claim']
    info_words = ['schedule', 'agenda', 'details', 'information', 'about', 'learn',
                  'location', 'venue', 'date', 'time', 'address', 'contact']

    features = []
    for text in texts:
        text_lower = text.lower()
        # Lowercased whitespace tokens, split once and reused by feature 20
        lower_tokens = text_lower.split()

        # Sentence count (split on . ! ?); at least 1 so the ratios never divide by 0
        sentences = re.split(r'[.!?]+', text)
        sentence_count = max(len([s for s in sentences if s.strip()]), 1)

        # 1. exclamation_density (per sentence, not raw count)
        exclamation_density = text.count('!') / sentence_count
        # 2. dollar_sign_count
        dollar_count = text.count('$')
        # 3. caps_word_ratio (proportion of ALL-CAPS words at least 2 characters long)
        words = text.split()
        caps_words = [w for w in words if w.isupper() and len(w) > 1]
        caps_word_ratio = len(caps_words) / max(len(words), 1)
        # 4. spam_phrase_count  5. ham_phrase_count  6. net_spam_context
        spam_phrase_count = _count_phrases(text_lower, spam_context_phrases)
        ham_phrase_count = _count_phrases(text_lower, ham_context_phrases)
        net_spam_context = spam_phrase_count - ham_phrase_count
        # 7. url_count
        url_count = len(re.findall(r'https?://\S+|www\.\S+', text))
        # 8. html_tag_count
        html_tag_count = len(re.findall(r'<[^>]+>', text))
        # 9. email_length
        email_length = len(text)
        # 10. avg_sentence_length (characters per sentence)
        avg_sentence_length = len(text) / sentence_count
        # 11. capitalization_ratio (char-level: uppercase share of all letters)
        alpha_chars = [c for c in text if c.isalpha()]
        upper_count = sum(1 for c in alpha_chars if c.isupper())
        cap_ratio = upper_count / max(len(alpha_chars), 1)

        # --- NEW FEATURES (12-24) ---
        # 12. has_specific_date (day-of-week, month + day, or numeric date)
        has_specific_date = 0
        for pat in date_patterns:
            if re.search(pat, text, re.IGNORECASE):
                has_specific_date = 1
                break
        # 13. has_specific_time (time with AM/PM)
        # (?!\w) instead of a trailing \b: after the last '.' of "a.m."/"p.m." a \b
        # would require a following letter/digit, so "3:00 p.m. today" never
        # matched. After "AM"/"pm", (?!\w) behaves exactly like \b.
        has_specific_time = 1 if re.search(r'\b\d{1,2}:\d{2}\s*(?:AM|PM|am|pm|a\.m\.|p\.m\.)(?!\w)', text) else 0
        # 14. date_reference_count (count of month name references)
        # NOTE: 'may' and 'march' also match the ordinary words, so this can over-count.
        date_reference_count = 0
        for m in months:
            date_reference_count = date_reference_count + len(re.findall(r'\b' + m + r'\b', text_lower))
        # 15. has_unsubscribe ('opt.out' covers "opt-out", "opt out", ...)
        has_unsubscribe = 1 if re.search(r'unsubscribe|opt.out', text_lower) else 0
        # 16. has_physical_address (street address pattern)
        has_physical_address = 1 if re.search(r'\d+\s+\w+\s+(?:St|Street|Ave|Avenue|Blvd|Boulevard|Dr|Drive|Rd|Road|Ln|Lane|Way|Ct|Court)\b', text, re.IGNORECASE) else 0
        # 17. has_proper_greeting (starts with Dear/Hello/Hi + name)
        has_proper_greeting = 1 if re.search(r'^(?:Dear|Hello|Hi|Good morning|Good afternoon)\s+\w', text.strip(), re.IGNORECASE) else 0
        # 18. has_contact_info (phone number or email in body)
        has_phone = bool(re.search(r'\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}', text))
        has_email_addr = bool(re.search(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', text))
        has_contact_info = 1 if (has_phone or has_email_addr) else 0
        # 19. registration_language_score
        registration_language_score = _count_phrases(text_lower, registration_phrases)
        # 20. cta_to_info_ratio (call-to-action words vs informational words;
        #     exact whole-token matches, so "click!" is not counted)
        cta_count = sum(lower_tokens.count(w) for w in cta_words)
        info_count = sum(lower_tokens.count(w) for w in info_words)
        cta_to_info_ratio = cta_count / max(info_count, 1)

        # 21-23 look at the host of every http(s) URL. len(urls) stays the
        # denominator of the shortener ratio.
        urls = re.findall(r'https?://([^\s/]+)', text)
        hosts = [_url_host(u) for u in urls]
        # 21. shortener_url_ratio (fraction of URLs whose host is a shortener)
        shortener_count = 0
        for host in hosts:
            if any(_host_matches(host, s) for s in url_shorteners):
                shortener_count = shortener_count + 1
        shortener_url_ratio = shortener_count / max(len(urls), 1)
        # 22. legitimate_platform_count (URLs hosted on known platforms)
        legitimate_platform_count = 0
        for host in hosts:
            if any(_host_matches(host, p) for p in legitimate_platforms):
                legitimate_platform_count = legitimate_platform_count + 1
        # 23. gov_edu_url_count (URL hosts ending in .gov or .edu)
        gov_edu_url_count = 0
        for host in hosts:
            if host.endswith(('.gov', '.edu')):
                gov_edu_url_count = gov_edu_url_count + 1
        # 24. question_mark_count (questions suggest conversation, ham signal)
        question_mark_count = text.count('?')

        features.append([exclamation_density, dollar_count, caps_word_ratio,
                         spam_phrase_count, ham_phrase_count, net_spam_context,
                         url_count, html_tag_count, email_length,
                         avg_sentence_length, cap_ratio,
                         has_specific_date, has_specific_time, date_reference_count,
                         has_unsubscribe, has_physical_address, has_proper_greeting,
                         has_contact_info, registration_language_score,
                         cta_to_info_ratio, shortener_url_ratio,
                         legitimate_platform_count, gov_edu_url_count,
                         question_mark_count])
    # reshape keeps the 2-D (n, 24) shape even when `texts` is empty
    return np.array(features, dtype=float).reshape(-1, len(META_FEATURE_NAMES))
# ---------------------------------------------------------------------------
# XAI helper functions (used by the student notebook)
# I put these here to keep the notebook short and easy to read.
# Each function is written in the same beginner style as the course lectures:
# plain for-loops, named variables, and comments that explain the "why".
# ---------------------------------------------------------------------------
def plot_push_direction(top10, title, save_path=None):
    """Horizontal bar chart of feature weights, colored by which class they favor.

    Parameters
    ----------
    top10 : list of (feature_name, weight) tuples
        Typically the top-10 list produced by LIME / SHAP / ELI5.
    title : str
        Chart title.
    save_path : str or Path, optional
        When given, the figure is also saved to this path (e.g.
        FIGURES_DIR / 'lime.png') before it is shown.

    Bars with weight >= 0 are steel blue (push toward spam); all other bars
    are salmon (push toward ham). The chart is displayed with plt.show();
    nothing is returned.
    """
    import matplotlib.pyplot as plt

    # One pass over the pairs builds the bar labels, lengths and colors together
    labels = []
    weights = []
    colors = []
    for name, weight in top10:
        labels.append(name)
        weights.append(weight)
        colors.append('steelblue' if weight >= 0 else 'salmon')

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(labels, weights, color=colors)
    # Dashed zero line: bars to its right favor spam, bars to its left favor ham
    ax.axvline(x=0, color='black', linestyle='--', linewidth=1)
    ax.set_xlabel('Feature Weight (positive -> spam, negative -> ham)', fontsize=12)
    ax.set_ylabel('Feature', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3, axis='x')
    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
def compare_top10_agreement(lime_top10, shap_top10, eli5_top10):
    """Print which features the LIME, SHAP and ELI5 top-10 lists have in common.

    Each argument is an iterable of (feature_name, score) pairs; only the
    names are compared and the scores are ignored. The report shows each
    tool's names (sorted), then the features shared by all three tools and
    by every pair of tools, with the overlap size written out of 10.

    Returns None -- the report is only printed.
    """
    def names_of(ranked_pairs):
        # Keep just the feature name from each (name, score) pair
        names = set()
        for name, _score in ranked_pairs:
            names.add(name)
        return names

    lime_names = names_of(lime_top10)
    shap_names = names_of(shap_top10)
    eli5_names = names_of(eli5_top10)

    print('=== Top-10 Feature Agreement ===')
    print('\nLIME Top-10:', sorted(lime_names))
    print('SHAP Top-10:', sorted(shap_names))
    print('ELI5 Top-10:', sorted(eli5_names))

    # (report line template, features shared) for the 3-way and each 2-way overlap
    overlaps = [
        ('\nAgreed by ALL THREE: %s (%d/10)', lime_names & shap_names & eli5_names),
        ('LIME-SHAP: %s (%d/10)', lime_names & shap_names),
        ('LIME-ELI5: %s (%d/10)', lime_names & eli5_names),
        ('SHAP-ELI5: %s (%d/10)', shap_names & eli5_names),
    ]
    for template, shared in overlaps:
        print(template % (sorted(shared), len(shared)))
def feature_reduction_experiment(feature_importances, X_train_dense, X_test_dense,
                                 y_train, y_test, baseline_model,
                                 baseline_accuracy, baseline_f1,
                                 reduction_pcts=(10, 20, 30), random_state=42):
    """Drop the least-important features, retrain, and see what happens to accuracy.

    This is the Kuzlu et al. (2020) methodology: if we remove the features that
    the XAI tool says don't matter, does the model still work?

    Parameters
    ----------
    feature_importances : 1-D array
        Importance score per feature (e.g. mean |SHAP|, permutation importance,
        or RF's built-in feature_importances_). Must contain exactly one score
        per column of X_train_dense.
    X_train_dense, X_test_dense : 2-D arrays
        Full (non-reduced) feature matrices.
    y_train, y_test : 1-D arrays
        Labels.
    baseline_model : fitted RandomForestClassifier
        Its complete hyperparameter set is copied with sklearn.base.clone, so
        each retrained model differs from the baseline only in its features.
        Only random_state is overridden, by the `random_state` argument.
    baseline_accuracy, baseline_f1 : float
        Accuracy and F1 from the original (full-feature) model.
    reduction_pcts : iterable of int
        Percentages of features to drop, e.g. [10, 20, 30]. Each must be in
        [0, 100) so at least one feature is always kept.
    random_state : int
        For reproducibility.

    Returns
    -------
    pandas.DataFrame
        One row per reduction percentage (including 0% = baseline).

    Raises
    ------
    ValueError
        If feature_importances does not hold one score per feature, or a
        reduction percentage is outside [0, 100).
    """
    import numpy as np
    import pandas as pd

    imp = np.array(feature_importances).flatten()
    n_features = X_train_dense.shape[1]
    # A length mismatch (e.g. per-class scores flattened to twice the feature
    # count) would otherwise select the wrong or non-existent columns.
    if imp.shape[0] != n_features:
        raise ValueError(
            'feature_importances has %d entries but X_train_dense has %d features'
            % (imp.shape[0], n_features))
    # Materialize once so a generator can be both validated and iterated
    pcts = list(reduction_pcts)
    for pct in pcts:
        # A negative pct would make the slice below keep only the most important
        # few features; 100 or more would leave no features to train on.
        if not 0 <= pct < 100:
            raise ValueError('reduction percentages must be in [0, 100), got %r' % (pct,))

    # Imported after validation so bad arguments fail before the heavy import
    from sklearn.base import clone
    from sklearn.metrics import accuracy_score, f1_score

    # Start the results list with the baseline row (no reduction)
    results = [{
        'Reduction %': 0,
        'Features': n_features,
        'Accuracy': baseline_accuracy,
        'F1': baseline_f1,
    }]

    # Feature indices from least to most important. The default argsort is not
    # stable, so the order among tied scores (many features often score 0) is
    # not guaranteed; a stable sort makes the dropped set reproducible.
    sorted_idx = np.argsort(imp, kind='stable')

    for pct in pcts:
        # Figure out how many features to drop, then keep the rest
        n_remove = int(len(imp) * pct / 100)
        keep_idx = sorted_idx[n_remove:]
        X_train_reduced = X_train_dense[:, keep_idx]
        X_test_reduced = X_test_dense[:, keep_idx]

        # Fresh, unfitted copy of the baseline with ALL its hyperparameters
        # (n_estimators, max_depth, class_weight, min_samples_leaf, ...), so the
        # only thing that changes between runs is the feature set.
        rf_reduced = clone(baseline_model)
        if 'random_state' in rf_reduced.get_params(deep=False):
            rf_reduced.set_params(random_state=random_state)
        rf_reduced.fit(X_train_reduced, y_train)
        y_pred_reduced = rf_reduced.predict(X_test_reduced)

        results.append({
            'Reduction %': pct,
            'Features': len(keep_idx),
            'Accuracy': accuracy_score(y_test, y_pred_reduced),
            'F1': f1_score(y_test, y_pred_reduced),
        })

    return pd.DataFrame(results)