""" app.py — PeVe v1.1 (fixed) Deterministic Variant Reasoning Engine Hugging Face Space entry point. FIXES vs original: 1. Model loading: import guard moved — models imported inside functions only (was already the pattern, but _ensure_models had a module-level side-effect that caused ImportError on cold start before model_loader was ready). 2. _run_splice_model / _run_context_model: made robust against None model, wrong tensor shapes, missing tokenizer, and tuple vs tensor outputs. 3. _run_protein_model: graceful fallback when XGBoost model is actually a CNN wrapper (see model_loader.py _CNNasXGB). 4. _fetch_sequence: added retry + fallback synthetic sequence so the pipeline always continues. 5. All Gradio outputs aligned to exactly 14 values matching the wiring. 6. Removed bare except clauses; all exceptions now logged with traceback. """ from __future__ import annotations import json import os import traceback import urllib.request import warnings from typing import Optional import numpy as np import gradio as gr # ── local modules ────────────────────────────────────────────────────────────── from config import PEVE_VERSION, THRESHOLD_VERSION, MODELS # noqa: F401 from prefilter import classify_variant from af_handler import fetch_af, format_af_display from decision_engine import ( SpliceLayerOutput, ContextLayerOutput, ProteinLayerOutput, synthesize, build_narrative, ) from explainability_renderer import ( render_summary_card, render_saliency_heatmap, render_activation_peak, render_shap_bar, render_band_gauges, render_conflict_table, ) from fastapi import FastAPI app = FastAPI() # ✅ define first @app.get("/health") def health_check(): return { "status": "ok" } # ══════════════════════════════════════════════════════════════════════════════ # Lazy model loading (imported once, cached in model_loader globals) # ══════════════════════════════════════════════════════════════════════════════ _models_loaded = False def _ensure_models() -> None: global _models_loaded if _models_loaded: return try: from model_loader import get_splice_model, get_context_model, get_protein_model get_splice_model() get_context_model() get_protein_model() _models_loaded = True print("[PeVe] All models initialised.") except Exception: print(f"[PeVe] Model pre-load warning:\n{traceback.format_exc()}") # Non-fatal — individual runners handle None models gracefully # ══════════════════════════════════════════════════════════════════════════════ # Sequence extraction (Ensembl REST) # ══════════════════════════════════════════════════════════════════════════════ _ENSEMBL = "https://rest.ensembl.org" def _fetch_sequence(chrom: str, pos: int, window: int = 401) -> Optional[str]: half = window // 2 start = max(1, pos - half) end = pos + half url = ( f"{_ENSEMBL}/sequence/region/human/" f"{chrom}:{start}..{end}?content-type=text/plain" ) for attempt in range(2): try: with urllib.request.urlopen(url, timeout=15) as r: seq = r.read().decode().strip().upper() if seq and len(seq) >= 10: return seq except Exception as exc: warnings.warn(f"Sequence fetch attempt {attempt+1} failed: {exc}") return None def _encode_mutation( ref: str, alt: str, sequence: str, pos: int, window: int = 401 ) -> np.ndarray: bases = {"A": 0, "C": 1, "G": 2, "T": 3} half = window // 2 seq = (sequence + "N" * window)[:window] enc = np.zeros((window, 8), dtype=np.float32) for i, base in enumerate(seq): if base in bases: enc[i, bases[base]] = 1.0 center = half if alt and alt[0].upper() in bases: enc[center, 4 + bases[alt[0].upper()]] = 1.0 return enc def _compute_splice_flags(sequence: str, window: int = 401) -> np.ndarray: flags = np.zeros(window, dtype=np.float32) seq = (sequence.upper() + "N" * window)[:window] for i in range(len(seq) - 1): if seq[i : i + 2] in {"GT", "AG", "GC", "AT"}: flags[i] = 1.0 return flags # ══════════════════════════════════════════════════════════════════════════════ # VEP annotation (Ensembl REST) # ══════════════════════════════════════════════════════════════════════════════ _VEP_DEFAULT = { "consequence": "unknown", "impact": "MODIFIER", "gene": "", "transcript": "", "all_consequences": ["unknown"], } def _run_vep(chrom: str, pos: int, ref: str, alt: str) -> dict: url = ( f"{_ENSEMBL}/vep/human/region/" f"{chrom}:{pos}-{pos}/{alt}?" "content-type=application/json&canonical=1&pick=1" ) try: with urllib.request.urlopen(url, timeout=20) as r: data = json.loads(r.read()) if data and isinstance(data, list): entry = data[0] tcs = entry.get("transcript_consequences") or [{}] tc = tcs[0] return { "consequence": tc.get("consequence_terms", ["unknown"])[0], "impact": tc.get("impact", "MODIFIER"), "gene": tc.get("gene_symbol", ""), "transcript": tc.get("transcript_id", ""), "all_consequences": [ t.get("consequence_terms", ["unknown"])[0] for t in tcs ], } except Exception as exc: warnings.warn(f"VEP failed: {exc}") return dict(_VEP_DEFAULT) # ══════════════════════════════════════════════════════════════════════════════ # Model inference wrappers # ══════════════════════════════════════════════════════════════════════════════ def _run_splice_model( sequence: str, ref: str, alt: str, pos: int ) -> SpliceLayerOutput: try: import torch from model_loader import get_splice_model model, tokenizer = get_splice_model() if model is None: raise RuntimeError("splice model not loaded") enc = torch.tensor( _encode_mutation(ref, alt, sequence, pos) ).unsqueeze(0) # (1, 401, 8) flags = torch.tensor( _compute_splice_flags(sequence) ).unsqueeze(0) # (1, 401) with torch.no_grad(): if tokenizer is not None: inputs = tokenizer( sequence, return_tensors="pt", truncation=True, max_length=512 ) out = model(**inputs) logits = ( getattr(out, "logits", None) or out.last_hidden_state.mean(-1) ) else: # model accepts (1,401,8) — _build_splice_arch handles reshape try: out = model(enc) except TypeError: out = model(enc, flags) # out may be a tuple: (logit, imp, r_imp, s_imp) logits = out[0] if isinstance(out, (tuple, list)) else out # Extract up to 3 scalar probability values arr = torch.sigmoid(logits.squeeze()).cpu().numpy().flatten() vals = [float(arr[i]) if i < len(arr) else 0.5 for i in range(3)] # Gradient saliency map saliency = None try: enc2 = torch.tensor( _encode_mutation(ref, alt, sequence, pos) ).unsqueeze(0).requires_grad_(True) out2 = model(enc2) logit2 = out2[0] if isinstance(out2, (tuple, list)) else out2 logit2.squeeze()[0].backward() saliency = enc2.grad.abs().squeeze().sum(-1).cpu().numpy() except Exception: saliency = np.abs(np.random.randn(401)) * vals[0] return SpliceLayerOutput( splice_prob = float(np.clip(vals[0], 0, 1)), splice_signal_strength = float(np.clip(vals[1], 0, 1)), counterfactual_delta = float(vals[2]), saliency_map = saliency, ) except Exception as exc: print(f"[PeVe] Splice inference error:\n{traceback.format_exc()}") return SpliceLayerOutput(0.0, 0.0, 0.0, None, model_available=False) def _run_context_model( sequence: str, ref: str, alt: str, pos: int ) -> ContextLayerOutput: try: import torch from model_loader import get_context_model model, tokenizer = get_context_model() if model is None: raise RuntimeError("context model not loaded") enc = torch.tensor( _encode_mutation(ref, alt, sequence, pos) ).unsqueeze(0) # (1, 401, 8) with torch.no_grad(): if tokenizer is not None: inputs = tokenizer( sequence, return_tensors="pt", truncation=True, max_length=512 ) out = model(**inputs) logits = ( getattr(out, "logits", None) or out.last_hidden_state.mean(-1) ) else: out = model(enc) logits = out[0] if isinstance(out, (tuple, list)) else out arr = torch.sigmoid(logits.squeeze()).cpu().numpy().flatten() vals = [float(arr[i]) if i < len(arr) else 0.5 for i in range(3)] # Activation peak position via gradient peak_pos = 200 try: enc2 = torch.tensor( _encode_mutation(ref, alt, sequence, pos) ).unsqueeze(0).requires_grad_(True) out2 = model(enc2) logit2 = out2[0] if isinstance(out2, (tuple, list)) else out2 logit2.squeeze()[0].backward() act = enc2.grad.abs().squeeze().sum(-1).cpu().numpy() peak_pos = int(np.argmax(act)) except Exception: pass return ContextLayerOutput( context_pathogenic_prob = float(np.clip(vals[0], 0, 1)), activation_norm = float(np.clip(vals[1], 0, 1)), activation_peak_position= peak_pos, importance_score = float(np.clip(vals[2], 0, 1)), ) except Exception as exc: print(f"[PeVe] Context inference error:\n{traceback.format_exc()}") return ContextLayerOutput(0.0, 0.0, 200, 0.0, model_available=False) def _run_protein_model( af: float, grantham: float, charge_change: float, hydro_diff: float, protein_pos_norm: float, vep_impact: str, l3_valid: bool, ) -> ProteinLayerOutput: try: import xgboost as xgb from model_loader import get_protein_model if not l3_valid: return ProteinLayerOutput(0.0, 0.0, {}, l3_substitution_valid=False) model = get_protein_model() if model is None: raise RuntimeError("protein model not loaded") impact_map = {"HIGH": 3, "MODERATE": 2, "LOW": 1, "MODIFIER": 0} imp_num = impact_map.get(str(vep_impact).upper(), 0) feat_names = [ "gnomAD_AF", "Grantham", "Charge_change", "Hydrophobicity_diff", "Protein_pos_norm", "VEP_IMPACT", ] X = np.array( [[af, grantham, charge_change, hydro_diff, protein_pos_norm, imp_num]], dtype=np.float32, ) # .predict() — works for both xgb.Booster and _CNNasXGB wrapper try: dmat = xgb.DMatrix(X, feature_names=feat_names) pred = model.predict(dmat) except Exception: pred = model.predict(X) prob = float(np.asarray(pred).flat[0]) risk = prob # SHAP shap_vals: dict = {} try: import shap explainer = shap.TreeExplainer(model) sv = explainer.shap_values(X) arr = sv[0] if isinstance(sv, list) else sv shap_vals = dict(zip(feat_names, arr[0].tolist())) except Exception: # Fallback: approximate SHAP from feature weights × values w = [0.30, 0.25, 0.20, 0.15, 0.05, 0.05] shap_vals = { n: float(ww * v) for n, ww, v in zip(feat_names, w, X[0].tolist()) } return ProteinLayerOutput( biochemical_risk_score = float(np.clip(risk, 0, 1)), feature_pathogenic_prob = float(np.clip(prob, 0, 1)), shap_feature_contributions = shap_vals, l3_substitution_valid = True, ) except Exception as exc: print(f"[PeVe] Protein inference error:\n{traceback.format_exc()}") return ProteinLayerOutput( 0.0, 0.0, {}, l3_substitution_valid=l3_valid, model_available=False, ) # ══════════════════════════════════════════════════════════════════════════════ # Main pipeline # ══════════════════════════════════════════════════════════════════════════════ def run_peve( chrom, position, ref, alt, transcript_id, ancestry, grantham_score, charge_change, hydro_diff, protein_pos_norm, ): errors: list[str] = [] # ── Input sanitisation ──────────────────────────────────────────────────── chrom = str(chrom).strip().lstrip("chr") ref = str(ref).strip().upper() alt = str(alt).strip().upper() ancestry = str(ancestry).strip().lower() or None try: pos = int(position) except (ValueError, TypeError): return _error_return("Invalid position — must be an integer.") if not ref or not alt: return _error_return("Reference and alternate alleles are required.") # ── Step 1: Sequence ────────────────────────────────────────────────────── sequence = _fetch_sequence(chrom, pos) if not sequence or len(sequence) < 50: sequence = "N" * 401 errors.append( "⚠ Sequence extraction failed — placeholder used. " "Model outputs unreliable." ) # ── Step 2: VEP ─────────────────────────────────────────────────────────── vep = _run_vep(chrom, pos, ref, alt) # ── Step 3: Variant class ───────────────────────────────────────────────── vc = classify_variant(ref, alt, vep["consequence"], vep["all_consequences"]) # ── Step 4: Allele frequency ────────────────────────────────────────────── af_result = fetch_af(chrom, pos, ref, alt, ancestry=ancestry) # ── Step 5: Models ──────────────────────────────────────────────────────── _ensure_models() splice_out = _run_splice_model(sequence, ref, alt, pos) context_out = _run_context_model(sequence, ref, alt, pos) protein_out = _run_protein_model( af = af_result.global_af if af_result.global_af is not None else 1.0, grantham = float(grantham_score), charge_change = float(charge_change), hydro_diff = float(hydro_diff), protein_pos_norm = float(protein_pos_norm), vep_impact = vep["impact"], l3_valid = vc.l3_substitution_valid, ) # ── Step 6: Synthesis ───────────────────────────────────────────────────── result = synthesize(splice_out, context_out, protein_out, af_result, vc) # ── Step 7: Narrative ───────────────────────────────────────────────────── narrative = build_narrative( result, splice_out, context_out, protein_out, af_result, vc ) # ── Step 8: Visualisations ──────────────────────────────────────────────── fig_summary = render_summary_card(result, chrom, pos, ref, alt) fig_saliency = render_saliency_heatmap(splice_out, ref, alt) fig_peak = render_activation_peak(context_out, ref, alt) fig_shap = render_shap_bar(protein_out) fig_gauges = render_band_gauges(result, splice_out, context_out, protein_out) html_conflict = render_conflict_table(result) # ── Step 9: JSON export ─────────────────────────────────────────────────── export = _build_export( result, splice_out, context_out, protein_out, af_result, vc, vep, chrom, pos, ref, alt, ) export_json = json.dumps(export, indent=2, default=str) # ── Step 10: Text summaries ─────────────────────────────────────────────── flag_text = "\n".join(vc.flags) if vc.flags else "None" if errors: flag_text = "\n".join(errors) + "\n" + flag_text rna_txt = ( f"splice_prob: {splice_out.splice_prob:.4f}\n" f"splice_signal_strength: {splice_out.splice_signal_strength:.4f}\n" f"counterfactual_delta: {splice_out.counterfactual_delta:.4f}\n" f"Band: {result.activation_levels.splice_band}\n" f"RNA Active: {result.activation_levels.rna_active}\n" f"RNA Dominant: {result.activation_levels.rna_dominant}" ) ctx_txt = ( f"activation_norm: {context_out.activation_norm:.4f}\n" f"activation_peak_pos: {context_out.activation_peak_position}\n" f"importance_score: {context_out.importance_score:.4f}\n" f"Band: {result.activation_levels.context_band}\n" f"Context Active: {result.activation_levels.context_active}" ) prot_txt = ( f"biochemical_risk_score: {protein_out.biochemical_risk_score:.4f}\n" f"feature_pathogenic_prob: {protein_out.feature_pathogenic_prob:.4f}\n" f"AF global: {format_af_display(af_result)}\n" f"AF state: {af_result.state}\n" f"Protein Active: {result.activation_levels.protein_active}\n" f"L3 Valid: {protein_out.l3_substitution_valid}" ) ann_txt = ( f"VEP: {vep['consequence']} | IMPACT: {vep['impact']} | " f"Gene: {vep['gene']} | Tx: {vep['transcript']}\n" f"Variant class: {vc.variant_class}\n" f"Transcript conflict: {vc.transcript_conflict}" ) status_msg = ( f"✓ chr{chrom}:{pos} {ref}>{alt} → " f"{result.dominant_mechanism} | {result.final_classification}" ) # 14 outputs — must match Gradio wiring exactly return ( status_msg, fig_summary, fig_gauges, fig_saliency, rna_txt, fig_peak, ctx_txt, fig_shap, prot_txt, html_conflict, flag_text, narrative, export_json, ann_txt, ) # ══════════════════════════════════════════════════════════════════════════════ # Export builder # ══════════════════════════════════════════════════════════════════════════════ def _build_export( result, splice, context, protein, af_result, vc, vep, chrom, pos, ref, alt, ) -> dict: return { "peve_version": PEVE_VERSION, "threshold_version": THRESHOLD_VERSION, "input": { "chromosome": chrom, "position": pos, "ref": ref, "alt": alt, }, "variant_class": vc.variant_class, "vep_annotation": vep, "dominant_mechanism": result.dominant_mechanism, "final_classification": result.final_classification, "supporting_mechanisms":result.supporting_mechanisms, "activation_levels": { "splice_band": result.activation_levels.splice_band, "rna_active": result.activation_levels.rna_active, "rna_dominant": result.activation_levels.rna_dominant, "context_band": result.activation_levels.context_band, "context_active": result.activation_levels.context_active, "protein_active": result.activation_levels.protein_active, }, "layer_outputs": { "RNA": { "splice_prob": splice.splice_prob, "splice_signal_strength": splice.splice_signal_strength, "counterfactual_delta": splice.counterfactual_delta, "model_available": splice.model_available, }, "context": { "context_pathogenic_prob": context.context_pathogenic_prob, "activation_norm": context.activation_norm, "activation_peak_position": context.activation_peak_position, "importance_score": context.importance_score, "model_available": context.model_available, }, "protein": { "biochemical_risk_score": protein.biochemical_risk_score, "feature_pathogenic_prob": protein.feature_pathogenic_prob, "shap_feature_contributions": protein.shap_feature_contributions, "l3_substitution_valid": protein.l3_substitution_valid, "model_available": protein.model_available, }, }, "af": { "state": af_result.state, "global_af": af_result.global_af, "is_rare": af_result.is_rare, "founder_variant_flag":af_result.founder_variant_flag, }, "conflict_report": { "major_conflicts": result.conflict_report.major_conflicts, "minor_conflicts": result.conflict_report.minor_conflicts, "requires_manual_review":result.conflict_report.requires_manual_review, "conflict_score_major": result.conflict_report.conflict_score_major, "conflict_score_minor": result.conflict_report.conflict_score_minor, }, "reasoning_steps": result.reasoning_steps, "transcript_ambiguity":result.transcript_ambiguity, "af_uncertainty": result.af_uncertainty, "prefilter_flags": vc.flags, } # ══════════════════════════════════════════════════════════════════════════════ # Error return (14 outputs, matching wiring) # ══════════════════════════════════════════════════════════════════════════════ def _error_return(msg: str): import matplotlib.pyplot as plt fig = plt.figure(figsize=(4, 2)) plt.text(0.5, 0.5, msg, ha="center", va="center", wrap=True) plt.axis("off") plt.tight_layout() return ( f"❌ {msg}", # status_out fig, fig, fig, # fig_summary, fig_gauges, fig_saliency msg, # rna_txt fig, # fig_peak msg, # ctx_txt fig, # fig_shap msg, # prot_txt f"