""" backend.py ========== Mutation Explainability Intelligence System — Backend Contains: - Ensembl sequence fetching - All matplotlib plot functions (return PIL Images) - Main run_pipeline() function - Helper utilities (_save_json, _error_outputs, _summary_md) Imported by app.py (Gradio UI layer). """ from __future__ import annotations import io import logging import os import tempfile import traceback import time from functools import lru_cache import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap import requests from PIL import Image as PILImage from model_loader import ModelRegistry, encode_for_v2, find_mutation_pos from explainability_engine import ( extract_splice_signals, extract_v4_signals, extract_classic_signals, compute_cross_model_analysis, ) from decision_engine import build_decision, DecisionResult logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", ) logger = logging.getLogger("mutation_xai.backend") # ── Global model registry (lazy-loaded on first request) ───────────────────── REGISTRY = ModelRegistry(hf_token=os.environ.get("HF_TOKEN")) # ═══════════════════════════════════════════════════════════════════════════════ # Ensembl sequence fetching # ═══════════════════════════════════════════════════════════════════════════════ ENSEMBL_URL = "https://rest.ensembl.org/sequence/region/human" WINDOW_HALF = 49 # 49 + 1 + 49 = 99 bp (matches all three models) @lru_cache(maxsize=256) def _fetch_ensembl(chrom: str, start: int, end: int) -> str: chrom = chrom.lstrip("chrCHR").strip() region = f"{chrom}:{start}..{end}:1" url = f"{ENSEMBL_URL}/{region}" for attempt in range(3): try: r = requests.get(url, params={"content-type": "application/json"}, timeout=15) if r.status_code == 429: time.sleep(int(r.headers.get("Retry-After", 5))) continue r.raise_for_status() data = r.json() if isinstance(data, list): data = data[0] return data.get("seq", "").upper() except Exception as e: if attempt == 2: raise RuntimeError(f"Ensembl API failed after 3 attempts: {e}") time.sleep(1.5 * (2 ** attempt)) return "" def fetch_window(chrom: str, pos: int) -> tuple[str, int]: """Fetch 99bp window centred on pos. Returns (ref_seq, mutation_pos_in_window).""" start = max(1, pos - WINDOW_HALF) end = pos + WINDOW_HALF seq = _fetch_ensembl(chrom.strip(), start, end) if not seq: raise ValueError(f"Empty sequence returned for chr{chrom}:{start}-{end}") seq = (seq + "N" * 99)[:99] mut_pos = max(0, min(98, pos - start)) return seq, mut_pos # ═══════════════════════════════════════════════════════════════════════════════ # Plot colour constants # ═══════════════════════════════════════════════════════════════════════════════ _BG = "#0D1117" _TEXT = "#E6EDF3" _MUTED = "#7D8590" _BLUE = "#58A6FF" _GREEN = "#3FB950" _RED = "#F85149" _ORG = "#D29922" _CMAP_ACT = LinearSegmentedColormap.from_list( "act", [(0.04, 0.22, 0.47), (0.96, 0.96, 0.96), (0.72, 0.05, 0.12)], N=256) _CMAP_SPL = LinearSegmentedColormap.from_list( "spl", [(0.0, "#f7f7f7"), (0.3, "#fee08b"), (0.6, "#fc8d59"), (1.0, "#d73027")]) # ═══════════════════════════════════════════════════════════════════════════════ # Plot utilities # ═══════════════════════════════════════════════════════════════════════════════ def _to_pil(fig) -> PILImage.Image: """Render matplotlib figure → PIL Image, then close the figure.""" buf = io.BytesIO() fig.savefig(buf, format="png", dpi=110, bbox_inches="tight", facecolor=fig.get_facecolor()) buf.seek(0) img = PILImage.open(buf).copy() plt.close(fig) return img def _empty_image(msg: str = "No data") -> PILImage.Image: fig, ax = plt.subplots(figsize=(6, 1.5), facecolor=_BG) ax.text(0.5, 0.5, msg, ha="center", va="center", color=_MUTED, fontsize=11) ax.axis("off") return _to_pil(fig) def _style(ax, title: str): ax.set_title(title, color=_TEXT, fontsize=9, loc="left", pad=4, fontweight="bold") for sp in ["top", "right"]: ax.spines[sp].set_visible(False) ax.spines["left"].set_color("#333") ax.spines["bottom"].set_color("#333") ax.tick_params(colors=_TEXT, labelsize=7) # ═══════════════════════════════════════════════════════════════════════════════ # Plot functions — all return PIL Images # ═══════════════════════════════════════════════════════════════════════════════ def plot_activation(profile: np.ndarray, mut_pos: int, label: str, prob: float) -> PILImage.Image: """CNN conv3 activation norm heatmap.""" imp = profile.copy() if imp.max() > 0: imp /= imp.max() fig, ax = plt.subplots(figsize=(14, 2.4), facecolor=_BG) ax.set_facecolor(_BG) im = ax.imshow(imp[np.newaxis, :], aspect="auto", cmap=_CMAP_ACT, vmin=0, vmax=1, extent=[-0.5, 98.5, 0, 1]) if mut_pos >= 0: ax.axvline(x=mut_pos, color=_GREEN, linewidth=2.0, linestyle="--", label=f"Mutation @ {mut_pos}") ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.6, loc="upper right") cb = fig.colorbar(im, ax=ax, pad=0.01) cb.set_label("Activation", color=_TEXT, fontsize=8) cb.ax.tick_params(colors=_TEXT, labelsize=7) ax.set_xlabel("Nucleotide position (99 bp)", color=_TEXT, fontsize=9) ax.set_xticks(range(0, 99, 10)) ax.set_yticks([]) _style(ax, f"conv3 Activation — {label} (prob={prob:.4f})") fig.tight_layout() return _to_pil(fig) def plot_splice_risk(ref_seq: str, mut_pos: int) -> PILImage.Image: """Splice site distance risk heatmap.""" seq = (ref_seq.upper() + "N" * 99)[:99] scores = np.zeros(99) donors, acceptors = [], [] for i in range(len(seq) - 1): if seq[i:i+2] == "GT": donors.append(i) if seq[i:i+2] == "AG": acceptors.append(i) for p in donors: for d in range(-8, 9): if 0 <= p + d < 99: scores[p + d] = max(scores[p + d], 0.5) for p in acceptors: for d in range(-8, 9): if 0 <= p + d < 99: scores[p + d] = max(scores[p + d], 0.5) for p in donors: if 0 <= p < 99: scores[p] = 1.0 for p in acceptors: if 0 <= p < 99: scores[p] = max(scores[p], 0.8) fig, ax = plt.subplots(figsize=(14, 2.4), facecolor=_BG) ax.set_facecolor(_BG) im = ax.imshow(scores[np.newaxis, :], aspect="auto", cmap=_CMAP_SPL, vmin=0, vmax=1, extent=[-0.5, 98.5, 0, 1]) if mut_pos >= 0: ax.axvline(x=mut_pos, color=_BLUE, linewidth=2.0, linestyle="--", label=f"Mutation @ {mut_pos}") ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.6, loc="upper right") cb = fig.colorbar(im, ax=ax, pad=0.01) cb.set_label("Splice risk", color=_TEXT, fontsize=8) cb.ax.tick_params(colors=_TEXT, labelsize=7) ax.set_xlabel("Nucleotide position (99 bp)", color=_TEXT, fontsize=9) ax.set_xticks(range(0, 99, 10)) ax.set_yticks([]) _style(ax, "Splice Distance Risk — GT donor / AG acceptor") fig.tight_layout() return _to_pil(fig) def plot_gradient(attr: np.ndarray, mut_pos: int, label: str) -> PILImage.Image: """Input-gradient attribution heatmap.""" fig, ax = plt.subplots(figsize=(14, 2.4), facecolor=_BG) ax.set_facecolor(_BG) im = ax.imshow(attr[np.newaxis, :], aspect="auto", cmap="PuOr", vmin=0, vmax=1, extent=[-0.5, 98.5, 0, 1]) if mut_pos >= 0: ax.axvline(x=mut_pos, color=_GREEN, linewidth=2.0, linestyle="--", label=f"Mutation @ {mut_pos}") ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.6, loc="upper right") cb = fig.colorbar(im, ax=ax, pad=0.01) cb.set_label("Attribution", color=_TEXT, fontsize=8) cb.ax.tick_params(colors=_TEXT, labelsize=7) ax.set_xlabel("Nucleotide position (99 bp)", color=_TEXT, fontsize=9) ax.set_xticks(range(0, 99, 10)) ax.set_yticks([]) _style(ax, f"Gradient Attribution — {label}") fig.tight_layout() return _to_pil(fig) def plot_counterfactual(cf_table: list, orig_prob: float, cf_delta: float) -> PILImage.Image: """Bar chart of all alternative substitution probabilities.""" if not cf_table: return _empty_image("No counterfactual data") labels = [r["mutation"] for r in cf_table] probs = [r["probability"] for r in cf_table] max_p, min_p = max(probs), min(probs) bar_colors = [_RED if p == max_p else (_BLUE if p == min_p else "#74add1") for p in probs] fig, ax = plt.subplots(figsize=(10, 3.4), facecolor=_BG) ax.set_facecolor(_BG) bars = ax.bar(labels, probs, color=bar_colors, edgecolor="#444", linewidth=0.7) ax.axhline(0.5, color=_MUTED, linestyle="--", linewidth=1.0, label="Boundary (0.5)") ax.axhline(orig_prob, color=_ORG, linestyle="-.", linewidth=1.5, label=f"Original ({orig_prob:.3f})") ax.set_ylim(0, 1.05) ax.set_xlabel("Alternative mutation", color=_TEXT, fontsize=10) ax.set_ylabel("Pathogenicity prob.", color=_TEXT, fontsize=10) ax.tick_params(colors=_TEXT) ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.5) for b, p in zip(bars, probs): ax.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.01, f"{p:.3f}", ha="center", va="bottom", fontsize=8, color=_TEXT) for sp in ["top", "right"]: ax.spines[sp].set_visible(False) ax.spines["left"].set_color("#333") ax.spines["bottom"].set_color("#333") ax.set_title( f"Counterfactual Analysis | Δ={cf_delta:.4f} | range {min_p:.3f}–{max_p:.3f}", color=_TEXT, fontsize=10, loc="left") fig.tight_layout() return _to_pil(fig) def plot_ablation(ablation: dict) -> PILImage.Image: """Horizontal bar chart of feature ablation causal effects.""" labels = [ "Splice features\n(donor/acceptor/region)", "Region features\n(exon/intron flags)", "Mutation type\n(one-hot encoding)", ] deltas = [ablation["splice_causal_effect"], ablation["region_causal_effect"], ablation["mutation_causal_effect"]] pcts = [ablation["splice_pct"], ablation["region_pct"], ablation["mutation_pct"]] fig, ax = plt.subplots(figsize=(9, 2.8), facecolor=_BG) ax.set_facecolor(_BG) bars = ax.barh(labels, deltas, color=[_RED, _ORG, _BLUE], edgecolor="#444", linewidth=0.6) ax.set_xlabel("Probability Δ when feature ablated", color=_TEXT, fontsize=9) ax.tick_params(colors=_TEXT, labelsize=8) ax.set_title( f"Feature Ablation | baseline prob={ablation['baseline_probability']:.4f}", color=_TEXT, fontsize=10, loc="left") for b, d, p in zip(bars, deltas, pcts): ax.text(b.get_width() + 0.002, b.get_y() + b.get_height() / 2, f" Δ{d:.4f} ({p}%)", va="center", fontsize=9, color=_TEXT) ax.set_xlim(0, max(deltas + [0.01]) * 1.7) for sp in ["top", "right"]: ax.spines[sp].set_visible(False) ax.spines["left"].set_color("#333") ax.spines["bottom"].set_color("#333") fig.tight_layout() return _to_pil(fig) def plot_xai_panel(xai) -> PILImage.Image: """Bar chart of cross-model explainability metrics.""" labels = ["Model\nAgreement", "XAI\nStrength", "CF\nMagnitude", "Locality\nScore", "Concentration\nIndex"] values = [ xai.model_agreement, xai.explainability_strength, min(xai.counterfactual_magnitude / 0.4, 1.0), xai.cross_model_locality_score, xai.signal_concentration_index, ] colors = [_GREEN if v >= 0.65 else (_ORG if v >= 0.40 else _RED) for v in values] fig, ax = plt.subplots(figsize=(10, 2.8), facecolor=_BG) ax.set_facecolor(_BG) bars = ax.bar(labels, values, color=colors, edgecolor="#444", linewidth=0.6, width=0.5) ax.axhline(0.65, color=_GREEN, linestyle="--", linewidth=0.8, alpha=0.6, label="High ≥0.65") ax.axhline(0.40, color=_ORG, linestyle="--", linewidth=0.8, alpha=0.6, label="Moderate ≥0.40") ax.set_ylim(0, 1.15) ax.set_ylabel("Score (0–1)", color=_TEXT, fontsize=9) ax.tick_params(colors=_TEXT, labelsize=8) ax.legend(fontsize=8, facecolor=_BG, labelcolor=_TEXT, framealpha=0.4, loc="upper right") for b, v in zip(bars, values): ax.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.02, f"{v:.3f}", ha="center", fontsize=9, color=_TEXT) for sp in ["top", "right"]: ax.spines[sp].set_visible(False) ax.spines["left"].set_color("#333") ax.spines["bottom"].set_color("#333") ax.set_title("Explainability Metrics Panel", color=_TEXT, fontsize=10, loc="left") fig.tight_layout() return _to_pil(fig) # ═══════════════════════════════════════════════════════════════════════════════ # Main pipeline # ═══════════════════════════════════════════════════════════════════════════════ def run_pipeline(chrom, position, ref_base, alt_base, exon_flag, intron_flag): """ Main Gradio callback. Returns tuple of 13 outputs matching the UI output list in app.py. """ chrom = str(chrom).strip() ref_base = str(ref_base).strip().upper() alt_base = str(alt_base).strip().upper() exon_flag = int(exon_flag) intron_flag = int(intron_flag) try: pos = int(str(position).strip().replace(",", "")) except (ValueError, TypeError): return _error_outputs(f"Invalid position: '{position}'") if ref_base not in "ACGT" or len(ref_base) != 1: return _error_outputs(f"Reference base must be A/C/G/T, got: '{ref_base}'") if alt_base not in "ACGT" or len(alt_base) != 1: return _error_outputs(f"Alternate base must be A/C/G/T, got: '{alt_base}'") if ref_base == alt_base: return _error_outputs("Reference and alternate bases are identical.") try: ref_seq, mut_pos = fetch_window(chrom, pos) actual = ref_seq[mut_pos].upper() if actual != ref_base: return _error_outputs( f"Reference mismatch at chr{chrom}:{pos}: " f"genome has '{actual}', you entered '{ref_base}'. " "Check chromosome/position or use the genome base shown." ) mut_seq = ref_seq[:mut_pos] + alt_base + ref_seq[mut_pos + 1:] splice_sig = extract_splice_signals(REGISTRY.splice, ref_seq, mut_seq, exon_flag, intron_flag) v4_sig = extract_v4_signals(REGISTRY.v4, ref_seq, mut_seq, exon_flag, intron_flag) classic_sig = extract_classic_signals(REGISTRY.classic, ref_seq, mut_seq) xai = compute_cross_model_analysis(splice_sig, v4_sig, classic_sig, mut_pos) result = build_decision( chrom=chrom, pos=pos, ref=ref_base, alt=alt_base, ref_seq=ref_seq, mut_seq=mut_seq, mutation_pos=mut_pos, splice=splice_sig, v4=v4_sig, classic=classic_sig, xai=xai, ) mp = result.mutation_pos plots = { "xai": plot_xai_panel(result.xai), "spl_act": plot_activation(result.splice.conv3_profile, mp, "Splice", result.splice.probability), "spl_risk": plot_splice_risk(result.ref_seq, mp), "spl_grad": plot_gradient(result.splice.gradient_attribution, mp, "Splice"), "v4_act": plot_activation(result.v4.conv3_profile, mp, "V4", result.v4.probability), "v4_grad": plot_gradient(result.v4.gradient_attribution, mp, "V4"), "cls_act": plot_activation(result.classic.conv3_profile, mp, "Classic", result.classic.probability), "cf": plot_counterfactual(result.splice.counterfactual_table, result.splice.probability, result.splice.counterfactual_delta), "abl": plot_ablation(result.splice.ablation), } json_str = result.to_json() json_path = _save_json(json_str) demo_note = ( "\n> ⚠️ **DEMO MODE** — random weights. Set HF_TOKEN for real predictions.\n" if REGISTRY.demo_mode else "" ) return ( _summary_md(result, demo_note), # 0 summary markdown result.final_explanation, # 1 explanation text plots["xai"], # 2 XAI metrics plots["spl_act"], # 3 splice activation plots["spl_risk"], # 4 splice risk heatmap plots["spl_grad"], # 5 splice gradient plots["v4_act"], # 6 v4 activation plots["v4_grad"], # 7 v4 gradient plots["cls_act"], # 8 classic activation plots["cf"], # 9 counterfactual chart plots["abl"], # 10 ablation chart json_str, # 11 json text json_path, # 12 download file path ) except Exception as exc: logger.error("Pipeline error: %s\n%s", exc, traceback.format_exc()) return _error_outputs(f"{exc}\n\n```\n{traceback.format_exc()}\n```") # ═══════════════════════════════════════════════════════════════════════════════ # Helpers # ═══════════════════════════════════════════════════════════════════════════════ def _save_json(json_str: str) -> str: tmp = tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w", encoding="utf-8") tmp.write(json_str) tmp.close() return tmp.name def _error_outputs(msg: str): """Return 13-element error tuple matching run_pipeline() output signature.""" blank = _empty_image("—") return ( f"❌ **Error**\n\n{msg}", # 0 summary markdown "", # 1 explanation text blank, blank, blank, blank, # 2-5 blank, blank, blank, # 6-8 blank, blank, # 9-10 "{}", # 11 json text None, # 12 download file ) def _summary_md(r: DecisionResult, note: str) -> str: """Build the summary markdown shown in the explanation panel.""" mech_icon = { "Splice-driven": "🔀", "Protein-driven": "🧬", "Consensus": "✅", "Ambiguous": "⚠️", }.get(r.dominant_mechanism, "❓") tier_icon = { "PATHOGENIC": "🔴", "LIKELY PATHOGENIC": "🟠", "POSSIBLY PATHOGENIC": "🟡", "LIKELY BENIGN": "🟢", "BENIGN": "🟢", }.get(r.risk_tier, "⚪") conf_icon = {"High": "🔵", "Moderate": "🟡", "Low": "🔴"}.get(r.confidence, "⚪") return f"""{note} ## {tier_icon} Risk Tier: **{r.risk_tier}** | Field | Value | |---|---| | **Variant** | `chr{r.chrom}:g.{r.pos}{r.ref}>{r.alt}` | | **Unified Probability** | `{r.unified_probability:.4f}` | | **Dominant Mechanism** | {mech_icon} {r.dominant_mechanism} | | **Confidence** | {conf_icon} {r.confidence} | | **Splice Model** | `{r.splice.probability:.4f}` — {r.splice.risk_tier} | | **V4 Model** | `{r.v4.probability:.4f}` | | **Classic Model** | `{r.classic.probability:.4f}` | --- ### Explainability Metrics | Metric | Value | |---|---| | **Mutation Peak Ratio** | `{r.xai.mutation_peak_ratio:.4f}` | | **Counterfactual Magnitude** | `{r.xai.counterfactual_magnitude:.4f}` | | **Cross-Model Locality** | `{r.xai.cross_model_locality_score:.4f}` | | **Signal Concentration** | `{r.xai.signal_concentration_index:.4f}` | | **XAI Strength Score** | `{r.xai.explainability_strength:.4f}` | | **Activation Pattern** | `{r.xai.activation_pattern_type}` | | **Model Agreement** | `{r.xai.model_agreement:.4f}` | --- ### Interpretation **Splice:** {r.splice_analysis[:320]}{'…' if len(r.splice_analysis) > 320 else ''} **Protein:** {r.protein_analysis[:260]}{'…' if len(r.protein_analysis) > 260 else ''} **Agreement:** {r.agreement_analysis[:260]}{'…' if len(r.agreement_analysis) > 260 else ''} """ #Content is user-generated and unverified