# spam-xai-model-v2 / retrain.py
# (Hugging Face file-page header captured with this file: uploaded by
# VoltageVagabond via huggingface_hub, commit 960ec3d (verified), 10.4 kB.)
# Batch retrain script for the spam classifier.
#
# Two modes:
# fast - single Random Forest, smaller TF-IDF (1000 features), no GridSearchCV.
# Takes ~5-10 minutes. Use this when you just want to verify the pipeline
# works after a small change.
# full - Voting ensemble (RF + LR + SVM), full TF-IDF (3000 features).
# Takes ~15-30 minutes. This is the production model.
# Both modes train on every available corpus plus feedback corrections
# (unless --no-feedback is passed).
#
# Reads feedback corrections from data/feedback/feedback_log.csv (unless
# --no-feedback is passed) and appends each correction to the training data
# once; no extra weighting is applied.
#
# Usage:
# python3 retrain.py --mode fast # quick smoke-test retrain
# python3 retrain.py --mode full # production retrain
# python3 retrain.py --mode full --no-feedback # full retrain, ignore user feedback
import sys
import csv
import argparse
import warnings
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report, precision_recall_curve
from scipy.sparse import hstack, csr_matrix
import joblib
from utils import preprocess_text, compute_metadata_features, META_FEATURE_NAMES
# Suppress FutureWarning and DeprecationWarning messages for the whole run.
for _ignored_category in (FutureWarning, DeprecationWarning):
    warnings.filterwarnings('ignore', category=_ignored_category)
del _ignored_category

# Filesystem layout, built from the directory that contains this script.
project_dir = Path(__file__).parent
data_dir = project_dir / 'data'
models_dir = project_dir / 'models'
# Feedback corrections live under data/feedback/.
feedback_csv = data_dir / 'feedback' / 'feedback_log.csv'

# Seed shared by the train/test split, the estimators and training-sample selection.
random_state = 42
def load_feedback_corrections(path=None):
    """Read the feedback log and return the emails users marked as misclassified.

    A row is kept only when its ``feedback`` column is exactly ``'wrong'``,
    its ``correct_label`` is non-blank and its ``email_text`` is non-empty.
    ``correct_label`` is compared case-insensitively with surrounding
    whitespace ignored: ``'spam'`` maps to 1, any other value maps to 0 (ham).

    Args:
        path: CSV file to read. Defaults to the module-level ``feedback_csv``
            (data/feedback/feedback_log.csv).

    Returns:
        DataFrame with columns ``['text', 'label']``. It is empty (but still
        has those columns) when the file is missing or holds no corrections.
    """
    path = Path(feedback_csv if path is None else path)
    if not path.exists():
        print("No feedback file found.")
        return pd.DataFrame(columns=['text', 'label'])
    corrections = []
    # newline='' is required by the csv module so quoted fields containing
    # line breaks (multi-line email bodies) are read back intact.
    with open(path, 'r', encoding='utf-8', newline='') as f:
        for row in csv.DictReader(f):
            # DictReader fills missing trailing fields with None, hence `or ''`.
            correct = (row.get('correct_label') or '').strip().lower()
            text = row.get('email_text') or ''
            if row.get('feedback') != 'wrong' or not correct or not text:
                continue
            corrections.append({
                'text': text,
                'label': 1 if correct == 'spam' else 0,
            })
    # Explicit columns keep the schema stable even when nothing matched.
    df = pd.DataFrame(corrections, columns=['text', 'label'])
    print(f"Found {len(df)} corrections in feedback log.")
    return df
def _load_training_frames():
    """Load every raw corpus present under data/raw/ as ['text', 'label'] frames.

    Missing corpora are skipped. Labels are 1 for spam/phishing and 0 for ham;
    rows whose label does not map, or whose text is missing, are dropped.
    Returns a (possibly empty) list of DataFrames.
    """
    raw_dir = data_dir / 'raw'
    enron_path = raw_dir / 'enron' / 'enron_spam_data.csv'
    puyang_path = raw_dir / 'puyang2025' / 'seven_phishing_emails.parquet'
    zefang_path = raw_dir / 'zefang' / 'phishing_emails.parquet'
    frames = []

    # Enron corpus — real corporate email (~33k), labelled 'spam'/'ham'.
    if enron_path.exists():
        enron_df = pd.read_csv(enron_path)
        enron_df = enron_df.rename(columns={'Message': 'text', 'Spam/Ham': 'label_str'})
        enron_df['label'] = enron_df['label_str'].str.lower().map({'spam': 1, 'ham': 0})
        enron_df = enron_df[['text', 'label']].dropna()
        frames.append(enron_df)
        print(f" Enron: {len(enron_df)} emails")

    # puyang2025 — 7 research corpora in one parquet (TREC-05/06/07, CEAS-08,
    # Enron subset, SpamAssassin, Ling-Spam). The Enron sub-corpus is dropped
    # to avoid duplicating the Enron data loaded above.
    if puyang_path.exists():
        puyang_df = pd.read_parquet(puyang_path)
        puyang_df = puyang_df[puyang_df['dataset_name'] != 'Enron']
        puyang_df = puyang_df.rename(columns={'label': 'label_int'})
        puyang_df['label'] = puyang_df['label_int'].map({0: 0, 1: 1})
        puyang_df = puyang_df[['text', 'label']].dropna()
        frames.append(puyang_df)
        print(f" puyang2025 (7 corpora, Enron excluded): {len(puyang_df)} emails")

    # zefang phishing dataset — ~18k emails labelled ham vs phishing.
    # phishing -> 1 (treated as the spam class), ham -> 0.
    if zefang_path.exists():
        zefang_df = pd.read_parquet(zefang_path)
        zefang_df['label'] = zefang_df['label'].map({'ham': 0, 'phishing': 1})
        zefang_df = zefang_df[['text', 'label']].dropna()
        frames.append(zefang_df)
        print(f" zefang phishing: {len(zefang_df)} emails")

    return frames


def _preprocess_chunk(chunk):
    """Apply ``preprocess_text`` to every email in *chunk* (one joblib task)."""
    return [preprocess_text(t) for t in chunk]


def _preprocess_parallel(texts, n_chunks=12):
    """Preprocess *texts* in parallel worker processes, preserving order.

    The input is cut into about ``n_chunks`` contiguous slices so each task
    handles many emails. joblib returns results in submission order, so the
    flattened output satisfies output[i] == preprocess_text(texts[i]).
    """
    chunk_size = max(1, len(texts) // n_chunks)
    chunks = [texts[i:i + chunk_size] for i in range(0, len(texts), chunk_size)]
    results = joblib.Parallel(n_jobs=-1, prefer='processes')(
        joblib.delayed(_preprocess_chunk)(chunk) for chunk in chunks
    )
    return [item for sublist in results for item in sublist]


def _build_model(mode):
    """Return the unfitted classifier for *mode* ('fast' or 'full')."""
    if mode == 'fast':
        print("Training single RandomForest (fast mode)...")
        return RandomForestClassifier(
            n_estimators=50, n_jobs=-1,
            class_weight='balanced', random_state=random_state)
    print("Training VotingClassifier ensemble (full mode)...")
    return VotingClassifier(
        estimators=[
            ('rf', RandomForestClassifier(
                n_estimators=200, n_jobs=-1,
                class_weight='balanced', random_state=random_state)),
            ('lr', LogisticRegression(
                max_iter=1000, class_weight='balanced', random_state=random_state)),
            # LinearSVC has no predict_proba; calibration provides it for soft voting.
            ('svm', CalibratedClassifierCV(
                LinearSVC(class_weight='balanced', max_iter=2000,
                          random_state=random_state))),
        ],
        voting='soft',
    )


def _best_f1_threshold(y_true, y_scores):
    """Return the probability threshold that maximises F1 on (y_true, y_scores).

    precision_recall_curve returns one more precision/recall value than
    thresholds: the final (precision=1, recall=0) point has no threshold, so
    it is excluded before taking the argmax.
    """
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_scores)
    precisions, recalls = precisions[:-1], recalls[:-1]
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
    return float(thresholds[int(np.argmax(f1_scores))])


def main():
    """Retrain the spam classifier from the command line and save its artifacts.

    Parses ``--mode`` (fast/full) and ``--no-feedback``, loads every available
    raw corpus (plus feedback corrections unless ``--no-feedback``),
    preprocesses the text, makes a stratified 80/20 split, fits TF-IDF and
    min-max-scaled metadata features on the training split, trains the
    classifier for the chosen mode, and prints a classification report and
    the F1-optimal threshold on the test split. Artifacts written to
    ``models/``:

    - ``voting_model.joblib``: fitted classifier (a plain RandomForest in fast mode)
    - ``tfidf_vectorizer.joblib`` and ``meta_scaler.joblib``
    - ``feature_names.joblib``: TF-IDF vocabulary followed by META_FEATURE_NAMES
    - ``optimal_threshold.joblib``: float threshold maximising test-set F1
    - ``training_sample.joblib``: up to 200 dense rows of the training matrix

    Exits with status 1 if no training corpus is found.
    """
    parser = argparse.ArgumentParser(description='Retrain the spam classifier')
    parser.add_argument('--mode', choices=['fast', 'full'], default='full',
                        help='fast = single RF / smaller TF-IDF (~5-10 min); full = voting ensemble (~15-30 min)')
    parser.add_argument('--no-feedback', action='store_true',
                        help='Ignore user feedback corrections from data/feedback/feedback_log.csv')
    args = parser.parse_args()
    mode = args.mode
    use_feedback = not args.no_feedback

    print(f"=== Retrain mode: {mode.upper()} ===")
    if mode == 'fast':
        print("Fast mode: single RandomForest, 1000 TF-IDF features, no grid search")
        print("Expected runtime: ~5-10 minutes")
    else:
        print("Full mode: VotingClassifier ensemble (RF + LR + SVM), 3000 TF-IDF features")
        print("Expected runtime: ~15-30 minutes")
    print()

    print("Loading training data...")
    frames = _load_training_frames()
    if not frames:
        print("ERROR: No training data found in data/ directory.")
        sys.exit(1)
    df = pd.concat(frames, ignore_index=True)
    print(f" Total original: {len(df)} emails")

    # Remember how many corrections were merged so the summary does not
    # have to re-read the feedback file.
    n_feedback = 0
    if use_feedback:
        feedback_df = load_feedback_corrections()
        n_feedback = len(feedback_df)
        if n_feedback > 0:
            df = pd.concat([df, feedback_df], ignore_index=True)
            print(f" After feedback merge: {len(df)} emails")

    print(f"Preprocessing {len(df)} emails (parallel)...")
    cleaned = _preprocess_parallel(df['text'].tolist())
    print(f" Done — {len(cleaned)} emails preprocessed")
    df['clean'] = cleaned
    df = df[df['clean'].str.len() > 0]

    # Split the raw text together with the cleaned text and labels so the
    # metadata features computed from raw text stay row-aligned with the
    # TF-IDF rows and labels. Extra arrays do not change the permutation.
    y = df['label'].to_numpy()
    (X_train_text, X_test_text,
     train_orig, test_orig,
     y_train, y_test) = train_test_split(
        df['clean'].to_numpy(), df['text'].to_numpy(), y,
        test_size=0.2, random_state=random_state, stratify=y,
    )

    # Smaller vocabulary and shorter n-grams in fast mode for speed.
    if mode == 'fast':
        max_feats, ngrams = 1000, (1, 2)
    else:
        max_feats, ngrams = 3000, (1, 3)
    print(f"Fitting TF-IDF (max_features={max_feats}, ngram_range={ngrams})...")
    tfidf = TfidfVectorizer(max_features=max_feats, ngram_range=ngrams,
                            min_df=2, max_df=0.95)
    X_train_tfidf = tfidf.fit_transform(X_train_text)
    X_test_tfidf = tfidf.transform(X_test_text)

    print("Computing metadata features...")
    X_train_meta = compute_metadata_features(train_orig.tolist())
    X_test_meta = compute_metadata_features(test_orig.tolist())
    scaler = MinMaxScaler()
    X_train_meta_scaled = scaler.fit_transform(X_train_meta)
    X_test_meta_scaled = scaler.transform(X_test_meta)
    # Explicit CSR so row indexing below works regardless of SciPy's default.
    X_train = hstack([X_train_tfidf, csr_matrix(X_train_meta_scaled)], format='csr')
    X_test = hstack([X_test_tfidf, csr_matrix(X_test_meta_scaled)], format='csr')
    feature_names_list = tfidf.get_feature_names_out().tolist() + META_FEATURE_NAMES

    ensemble = _build_model(mode)
    ensemble.fit(X_train, y_train)
    y_pred = ensemble.predict(X_test)
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=['Ham', 'Spam']))

    y_scores = ensemble.predict_proba(X_test)[:, 1]
    optimal_threshold = _best_f1_threshold(y_test, y_scores)
    print(f"Optimal threshold: {optimal_threshold:.4f}")

    models_dir.mkdir(parents=True, exist_ok=True)
    joblib.dump(ensemble, models_dir / 'voting_model.joblib')
    joblib.dump(tfidf, models_dir / 'tfidf_vectorizer.joblib')
    joblib.dump(scaler, models_dir / 'meta_scaler.joblib')
    joblib.dump(feature_names_list, models_dir / 'feature_names.joblib')
    joblib.dump(optimal_threshold, models_dir / 'optimal_threshold.joblib')

    # Deterministic sample of at most 200 training rows, stored densely.
    sample_size = min(200, X_train.shape[0])
    sample_idx = np.random.RandomState(random_state).choice(
        X_train.shape[0], sample_size, replace=False)
    training_sample = X_train[sample_idx].toarray()
    joblib.dump(training_sample, models_dir / 'training_sample.joblib')

    print(f"\nAll models saved to {models_dir}/")
    if use_feedback:
        print(f"Feedback corrections incorporated: {n_feedback}")


if __name__ == '__main__':
    main()