# ── Lazy model imports (loaded once on first use) ─────────
_models_loaded = False


def _ensure_models():
    """Call the three model getters once per process.

    The first call imports ``model_loader`` and invokes
    ``get_splice_model``, ``get_context_model`` and ``get_protein_model``;
    later calls return immediately because ``_models_loaded`` is set.
    """
    global _models_loaded
    if _models_loaded:
        return
    from model_loader import get_splice_model, get_context_model, get_protein_model

    get_splice_model()
    get_context_model()
    get_protein_model()
    _models_loaded = True


# ══════════════════════════════════════════════════════════
# Sequence extraction (Ensembl REST)
# ══════════════════════════════════════════════════════════

def _fetch_sequence(chrom: str, pos: int, window: int = 401) -> Optional[str]:
    """Fetch the reference sequence around ``pos`` from the Ensembl REST API.

    Args:
        chrom: Chromosome name as used in the Ensembl region string.
        pos: 1-based position of the variant.
        window: Requested window size; ``window // 2`` bases are taken on
            each side of ``pos``.

    Returns:
        The upper-cased sequence, or ``None`` when the request fails or the
        position/window is not positive (a warning is emitted in both cases).
        When the window would start before base 1, the result is left-padded
        with ``N`` so that ``pos`` stays at index ``window // 2``, which is
        the index ``_encode_mutation`` marks as the variant site.
    """
    import urllib.parse
    import urllib.request

    if pos < 1 or window < 1:
        warnings.warn(
            f"Sequence fetch failed: invalid position/window ({pos}, {window})"
        )
        return None

    half = window // 2
    start = max(1, pos - half)
    end = pos + half
    # Number of bases lost on the left because the start was clamped to 1.
    left_pad = start - (pos - half)

    # Quote the user-supplied chromosome so it cannot alter the URL path.
    region = urllib.parse.quote(f"{chrom}:{start}..{end}", safe=":.")
    url = (
        f"https://rest.ensembl.org/sequence/region/human/"
        f"{region}?content-type=text/plain"
    )
    try:
        with urllib.request.urlopen(url, timeout=15) as r:
            seq = r.read().decode().strip().upper()
    except Exception as exc:
        warnings.warn(f"Sequence fetch failed: {exc}")
        return None

    # An empty response stays empty so callers' length checks still trip.
    return "N" * left_pad + seq if seq else seq


def _encode_mutation(ref: str, alt: str, sequence: str, pos: int,
                     window: int = 401) -> np.ndarray:
    """One-hot encode ``sequence`` and mark the alternate allele at the centre.

    Columns 0-3 hold the reference base (A, C, G, T) of each of the first
    ``window`` positions; columns 4-7 hold the first base of ``alt`` at row
    ``window // 2`` only. Characters other than A/C/G/T leave their row at
    zero. ``ref`` and ``pos`` are accepted for signature compatibility and
    are not used.

    Returns:
        A ``(window, 8)`` float32 array.
    """
    bases = {"A": 0, "C": 1, "G": 2, "T": 3}
    enc = np.zeros((window, 8), dtype=np.float32)

    # Upper-case first, matching _compute_splice_flags, so soft-masked
    # (lower-case) bases are encoded rather than silently dropped.
    for i, base in enumerate(sequence[:window].upper()):
        column = bases.get(base)
        if column is not None:
            enc[i, column] = 1.0

    center = window // 2
    alt_base = alt[0].upper() if alt else ""
    if alt_base in bases and center < window:
        enc[center, 4 + bases[alt_base]] = 1.0
    return enc


def _compute_splice_flags(sequence: str, window: int = 401) -> np.ndarray:
    """Flag positions where a GT, AG, GC or AT dinucleotide starts.

    Args:
        sequence: Nucleotide sequence; compared case-insensitively.
        window: Length of the returned vector (defaults to 401, the size
            used by the rest of this module).

    Returns:
        A float32 vector of length ``window`` with 1.0 at every index ``i``
        where ``sequence[i:i+2]`` is one of the four dinucleotides. Input
        beyond ``window`` positions is ignored instead of overflowing the
        buffer.
    """
    motifs = {"GT", "AG", "GC", "AT"}
    flags = np.zeros(window, dtype=np.float32)
    seq = sequence.upper()
    for i in range(min(len(seq) - 1, window)):
        if seq[i:i + 2] in motifs:
            flags[i] = 1.0
    return flags


# ══════════════════════════════════════════════════════════
# VEP annotation (Ensembl REST)
# ══════════════════════════════════════════════════════════

def _run_vep(chrom: str, pos: int, ref: str, alt: str) -> dict:
    """Annotate a variant through the Ensembl VEP REST region endpoint.

    Returns:
        A dict with keys ``consequence``, ``impact``, ``gene``,
        ``transcript`` and ``all_consequences``. Values come from the first
        transcript consequence of the first VEP record; missing fields fall
        back to ``"unknown"`` / ``"MODIFIER"`` / ``""``. On any request or
        parsing failure a warning is emitted and the all-default dict is
        returned. ``ref`` is not used in the request.
    """
    import urllib.request

    fallback = {
        "consequence": "unknown",
        "impact": "MODIFIER",
        "gene": "",
        "transcript": "",
        "all_consequences": ["unknown"],
    }
    url = (
        f"https://rest.ensembl.org/vep/human/region/"
        f"{chrom}:{pos}-{pos}/{alt}?"
        "content-type=application/json&canonical=1&pick=1"
    )

    def first_term(record: dict) -> str:
        # Guard against a missing OR empty consequence_terms list.
        terms = record.get("consequence_terms") or ["unknown"]
        return terms[0]

    try:
        with urllib.request.urlopen(url, timeout=20) as r:
            data = json.loads(r.read())
        if data and isinstance(data, list):
            entry = data[0]
            transcripts = entry.get("transcript_consequences") or []
            tc = transcripts[0] if transcripts else {}
            return {
                "consequence": first_term(tc),
                "impact": tc.get("impact", "MODIFIER"),
                "gene": tc.get("gene_symbol", ""),
                "transcript": tc.get("transcript_id", ""),
                # Keep the same shape as the fallback when VEP reports no
                # transcript consequences at all.
                "all_consequences": (
                    [first_term(t) for t in transcripts] or ["unknown"]
                ),
            }
    except Exception as exc:
        warnings.warn(f"VEP failed: {exc}")
    return fallback


# ══════════════════════════════════════════════════════════
# Model inference wrappers
# ══════════════════════════════════════════════════════════

def _model_logits(model, tokenizer, sequence, enc, flags=None):
    """Run one no-grad forward pass and return the raw output tensor.

    With a tokenizer, ``sequence`` is tokenised and the model's ``logits``
    attribute is used, falling back to the mean of ``last_hidden_state``
    over its last axis when ``logits`` is absent. Without a tokenizer the
    model is called on ``enc``; if that raises and ``flags`` is given, the
    call is retried as ``model(enc, flags)``.
    """
    import torch

    with torch.no_grad():
        if tokenizer:
            inputs = tokenizer(sequence, return_tensors="pt",
                               truncation=True, max_length=512)
            out = model(**inputs)
            # `logits or ...` would evaluate the tensor's truth value, which
            # raises for multi-element tensors; test for None explicitly.
            logits = getattr(out, "logits", None)
            if logits is None:
                logits = out.last_hidden_state.mean(-1)
            return logits
        try:
            out = model(enc)
        except Exception:
            if flags is None:
                raise
            out = model(enc, flags)
        return out if isinstance(out, torch.Tensor) else out[0]


def _sigmoid_triplet(logits) -> list:
    """Return the first three sigmoid values of ``logits`` as floats.

    Values are read in flattened order; missing entries are filled with 0.5.
    """
    import torch

    probs = torch.sigmoid(logits.squeeze()).detach().cpu().numpy().flat
    return [float(next(probs, 0.5)) for _ in range(3)]


def _input_gradient_profile(model, encoded: np.ndarray) -> np.ndarray:
    """Return per-position absolute input gradients of the first model output.

    ``encoded`` is the array produced by ``_encode_mutation``. The gradient
    of the first flattened output element with respect to the input is
    taken, made absolute, and summed over the channel axis.
    ``torch.autograd.grad`` is used so that no gradients are accumulated on
    the model's parameters.
    """
    import torch

    enc = torch.tensor(encoded).unsqueeze(0)
    enc.requires_grad_(True)
    out = model(enc)
    score = (out if isinstance(out, torch.Tensor) else out[0]).reshape(-1)[0]
    (grad,) = torch.autograd.grad(score, enc)
    return grad.abs().squeeze(0).sum(-1).cpu().numpy()


def _run_splice_model(sequence: str, ref: str, alt: str, pos: int) -> SpliceLayerOutput:
    """Run the splice model and wrap its outputs in a ``SpliceLayerOutput``.

    The first three sigmoid outputs become ``splice_prob``,
    ``splice_signal_strength`` (both clipped to [0, 1]) and
    ``counterfactual_delta``. The saliency map is the input-gradient
    profile; when it cannot be computed a warning is emitted and
    ``saliency_map`` is ``None`` (no synthetic values are generated).

    Any other failure is reported as a warning and yields an all-zero
    output with ``model_available=False``.
    """
    try:
        import torch
        from model_loader import get_splice_model

        model, tokenizer = get_splice_model()
        if model is None:
            raise RuntimeError("not loaded")

        encoded = _encode_mutation(ref, alt, sequence, pos)
        enc = torch.tensor(encoded).unsqueeze(0)
        flags = torch.tensor(_compute_splice_flags(sequence)).unsqueeze(0)
        vals = _sigmoid_triplet(
            _model_logits(model, tokenizer, sequence, enc, flags)
        )

        # Gradient saliency
        try:
            saliency = _input_gradient_profile(model, encoded)
        except Exception as exc:
            warnings.warn(f"Splice saliency unavailable: {exc}")
            saliency = None

        return SpliceLayerOutput(
            splice_prob=float(np.clip(vals[0], 0, 1)),
            splice_signal_strength=float(np.clip(vals[1], 0, 1)),
            counterfactual_delta=float(vals[2]),
            saliency_map=saliency,
        )
    except Exception as exc:
        warnings.warn(f"Splice inference error: {exc}")
        return SpliceLayerOutput(0.0, 0.0, 0.0, None, model_available=False)


def _run_context_model(sequence: str, ref: str, alt: str, pos: int) -> ContextLayerOutput:
    """Run the context model and wrap its outputs in a ``ContextLayerOutput``.

    The first three sigmoid outputs become ``context_pathogenic_prob``,
    ``activation_norm`` and ``importance_score`` (each clipped to [0, 1]).
    ``activation_peak_position`` is the argmax of the input-gradient
    profile, or 200 when that profile cannot be computed (a warning is
    emitted).

    Any other failure is reported as a warning and yields an all-zero
    output with peak position 200 and ``model_available=False``.
    """
    try:
        import torch
        from model_loader import get_context_model

        model, tokenizer = get_context_model()
        if model is None:
            raise RuntimeError("not loaded")

        encoded = _encode_mutation(ref, alt, sequence, pos)
        enc = torch.tensor(encoded).unsqueeze(0)
        vals = _sigmoid_triplet(_model_logits(model, tokenizer, sequence, enc))

        peak_pos = 200
        try:
            peak_pos = int(np.argmax(_input_gradient_profile(model, encoded)))
        except Exception as exc:
            warnings.warn(f"Context activation peak unavailable: {exc}")

        return ContextLayerOutput(
            context_pathogenic_prob=float(np.clip(vals[0], 0, 1)),
            activation_norm=float(np.clip(vals[1], 0, 1)),
            activation_peak_position=peak_pos,
            importance_score=float(np.clip(vals[2], 0, 1)),
        )
    except Exception as exc:
        warnings.warn(f"Context inference error: {exc}")
        return ContextLayerOutput(0.0, 0.0, 200, 0.0, model_available=False)


def _run_protein_model(af, grantham, charge_change, hydro_diff,
                       protein_pos_norm, vep_impact, l3_valid) -> ProteinLayerOutput:
    """Score the variant with the protein-level model.

    Args:
        af, grantham, charge_change, hydro_diff, protein_pos_norm: Numeric
            features, passed to the model in that order.
        vep_impact: VEP impact label; HIGH/MODERATE/LOW/MODIFIER map to
            3/2/1/0 and anything else to 0.
        l3_valid: When falsy the model is skipped and a zero output with
            ``l3_substitution_valid=False`` is returned.

    Returns:
        A ``ProteinLayerOutput``. The first prediction value is used for
        both ``biochemical_risk_score`` and ``feature_pathogenic_prob``
        (clipped to [0, 1]). On failure a warning is emitted and a zero
        output with ``model_available=False`` is returned.
    """
    # Not applicable: decide before importing xgboost so a missing optional
    # dependency is not reported as a model failure for this case.
    if not l3_valid:
        return ProteinLayerOutput(0.0, 0.0, {}, l3_substitution_valid=False)

    try:
        import xgboost as xgb
        from model_loader import get_protein_model

        model = get_protein_model()
        if model is None:
            raise RuntimeError("not loaded")

        impact_map = {"HIGH": 3, "MODERATE": 2, "LOW": 1, "MODIFIER": 0}
        imp_num = impact_map.get(str(vep_impact or "").upper(), 0)
        feat_names = ["gnomAD_AF", "Grantham", "Charge_change",
                      "Hydrophobicity_diff", "Protein_pos_norm", "VEP_IMPACT"]
        X = np.array([[af, grantham, charge_change, hydro_diff,
                       protein_pos_norm, imp_num]], dtype=np.float32)

        # Try the DMatrix form first, then fall back to the raw array.
        try:
            dmat = xgb.DMatrix(X, feature_names=feat_names)
            pred = model.predict(dmat)
        except Exception:
            pred = model.predict(X)
        prob = float(pred.flat[0])
        risk = prob

        try:
            import shap

            explainer = shap.TreeExplainer(model)
            sv = explainer.shap_values(X)
            arr = sv[0] if isinstance(sv, list) else sv
            shap_vals = dict(zip(feat_names, arr[0].tolist()))
        except Exception:
            # Heuristic stand-in when SHAP is unavailable: fixed weights
            # times the raw feature values. These are NOT SHAP values.
            w = [0.30, 0.25, 0.20, 0.15, 0.05, 0.05]
            shap_vals = {n: float(ww * v)
                         for n, ww, v in zip(feat_names, w, X[0].tolist())}

        return ProteinLayerOutput(
            biochemical_risk_score=float(np.clip(risk, 0, 1)),
            feature_pathogenic_prob=float(np.clip(prob, 0, 1)),
            shap_feature_contributions=shap_vals,
            l3_substitution_valid=True,
        )
    except Exception as exc:
        warnings.warn(f"Protein inference error: {exc}")
        return ProteinLayerOutput(0.0, 0.0, {}, l3_substitution_valid=l3_valid,
                                  model_available=False)
# ══════════════════════════════════════════════════════════
# Main pipeline
# ══════════════════════════════════════════════════════════

def _normalize_chrom(chrom) -> str:
    """Return ``chrom`` as a trimmed string without a leading ``chr`` prefix.

    The prefix is matched case-insensitively and removed once; ``None``
    becomes the empty string. (``str.lstrip("chr")`` is not used because it
    strips any run of the characters c/h/r, e.g. ``hs37d5`` -> ``s37d5``.)
    """
    name = "" if chrom is None else str(chrom).strip()
    if name[:3].lower() == "chr":
        name = name[3:]
    return name


def _normalize_allele(allele) -> str:
    """Return ``allele`` trimmed and upper-cased; ``None`` becomes ``""``."""
    return "" if allele is None else str(allele).strip().upper()


def _normalize_ancestry(ancestry):
    """Return ``ancestry`` trimmed and lower-cased, or ``None`` when blank."""
    if ancestry is None:
        return None
    return str(ancestry).strip().lower() or None


def _coerce_float(value, label: str, errors: list, default: float = 0.0) -> float:
    """Convert ``value`` to float, recording a flag and using ``default`` on failure."""
    try:
        return float(value)
    except (TypeError, ValueError):
        errors.append(f"⚠ {label} missing or non-numeric — defaulted to {default}.")
        return default


def run_peve(chrom, position, ref, alt, transcript_id, ancestry,
             grantham_score, charge_change, hydro_diff, protein_pos_norm):
    """Run the full PeVe pipeline for one variant.

    Steps: fetch sequence, VEP annotation, variant classification, allele
    frequency lookup, the three model layers, synthesis, narrative, figures
    and JSON export. ``transcript_id`` is accepted but not used here.

    Returns:
        A 14-tuple: status line, summary figure, gauge figure, saliency
        figure, RNA text, activation-peak figure, context text, SHAP figure,
        protein text, conflict HTML, flag text, narrative, export JSON and
        annotation text. Invalid input returns ``_error_return(...)``
        instead. Non-numeric protein features are replaced by 0.0 and
        reported in the flag text rather than raising.
    """
    errors = []
    chrom = _normalize_chrom(chrom)
    ref = _normalize_allele(ref)
    alt = _normalize_allele(alt)
    ancestry = _normalize_ancestry(ancestry)

    try:
        pos = int(position)
    except (ValueError, TypeError):
        return _error_return("Invalid position — must be an integer.")
    if pos < 1:
        return _error_return("Invalid position — must be a positive integer.")
    if not chrom:
        return _error_return("Chromosome is required.")
    if not ref or not alt:
        return _error_return("Reference and alternate alleles are required.")

    # Step 1: Sequence
    sequence = _fetch_sequence(chrom, pos)
    if not sequence or len(sequence) < 50:
        sequence = "N" * 401
        errors.append("⚠ Sequence extraction failed — placeholder used. Model outputs unreliable.")

    # Step 2: VEP
    vep = _run_vep(chrom, pos, ref, alt)

    # Step 3: Variant class
    vc = classify_variant(ref, alt, vep["consequence"], vep["all_consequences"])

    # Step 4: AF
    af_result = fetch_af(chrom, pos, ref, alt, ancestry=ancestry)

    # Step 5: Models
    _ensure_models()
    splice_out = _run_splice_model(sequence, ref, alt, pos)
    context_out = _run_context_model(sequence, ref, alt, pos)
    # NOTE(review): an unknown allele frequency is fed to the model as 1.0;
    # confirm that treating "unknown" as "common" is the intended policy.
    protein_out = _run_protein_model(
        af=af_result.global_af if af_result.global_af is not None else 1.0,
        grantham=_coerce_float(grantham_score, "Grantham score", errors),
        charge_change=_coerce_float(charge_change, "Charge change", errors),
        hydro_diff=_coerce_float(hydro_diff, "Hydrophobicity difference", errors),
        protein_pos_norm=_coerce_float(protein_pos_norm, "Normalised protein position", errors),
        vep_impact=vep["impact"],
        l3_valid=vc.l3_substitution_valid,
    )

    # Step 6: Synthesis
    result = synthesize(splice_out, context_out, protein_out, af_result, vc)

    # Step 7: Narrative
    narrative = build_narrative(result, splice_out, context_out, protein_out, af_result, vc)

    # Step 8: Visuals
    fig_summary = render_summary_card(result, chrom, pos, ref, alt)
    fig_saliency = render_saliency_heatmap(splice_out, ref, alt)
    fig_peak = render_activation_peak(context_out, ref, alt)
    fig_shap = render_shap_bar(protein_out)
    fig_gauges = render_band_gauges(result, splice_out, context_out, protein_out)
    html_conflict = render_conflict_table(result)

    # Step 9: Export
    export = _build_export(result, splice_out, context_out, protein_out,
                           af_result, vc, vep, chrom, pos, ref, alt)
    export_json = json.dumps(export, indent=2, default=str)

    flag_text = "\n".join(vc.flags) if vc.flags else "None"
    if errors:
        flag_text = "\n".join(errors) + "\n" + flag_text

    levels = result.activation_levels
    rna_txt = (
        f"splice_prob: {splice_out.splice_prob:.4f}\n"
        f"splice_signal_strength: {splice_out.splice_signal_strength:.4f}\n"
        f"counterfactual_delta: {splice_out.counterfactual_delta:.4f}\n"
        f"Band: {levels.splice_band}\n"
        f"RNA Active: {levels.rna_active}\n"
        f"RNA Dominant: {levels.rna_dominant}"
    )
    ctx_txt = (
        f"activation_norm: {context_out.activation_norm:.4f}\n"
        f"activation_peak_position: {context_out.activation_peak_position}\n"
        f"importance_score: {context_out.importance_score:.4f}\n"
        f"Band: {levels.context_band}\n"
        f"Context Active: {levels.context_active}"
    )
    prot_txt = (
        f"biochemical_risk_score: {protein_out.biochemical_risk_score:.4f}\n"
        f"feature_pathogenic_prob: {protein_out.feature_pathogenic_prob:.4f}\n"
        f"AF global: {format_af_display(af_result)}\n"
        f"AF state: {af_result.state}\n"
        f"Protein Active: {levels.protein_active}\n"
        f"L3 Valid: {protein_out.l3_substitution_valid}"
    )
    ann_txt = (
        f"VEP: {vep['consequence']} | IMPACT: {vep['impact']} | "
        f"Gene: {vep['gene']} | Tx: {vep['transcript']}\n"
        f"Variant class: {vc.variant_class}\n"
        f"Transcript conflict: {vc.transcript_conflict}"
    )

    return (
        f"✓ chr{chrom}:{pos} {ref}>{alt} → {result.dominant_mechanism} | "
        f"{result.final_classification}",
        fig_summary, fig_gauges,
        fig_saliency, rna_txt,
        fig_peak, ctx_txt,
        fig_shap, prot_txt,
        html_conflict,
        flag_text,
        narrative,
        export_json,
        ann_txt,
    )


def _export_fields(obj, *names) -> dict:
    """Return ``{name: getattr(obj, name)}`` for ``names``, in the given order."""
    return {name: getattr(obj, name) for name in names}


def _build_export(result, splice, context, protein, af_result, vc, vep,
                  chrom, pos, ref, alt) -> dict:
    """Assemble the export dictionary for one analysed variant.

    Collects version constants, the input coordinates, the VEP annotation,
    the synthesis result, each layer's outputs, allele-frequency fields and
    the conflict report into one nested dict. The caller serialises it with
    ``json.dumps(..., default=str)``.
    """
    return {
        "peve_version": PEVE_VERSION,
        "threshold_version": THRESHOLD_VERSION,
        "input": {"chromosome": chrom, "position": pos, "ref": ref, "alt": alt},
        "variant_class": vc.variant_class,
        "vep_annotation": vep,
        "dominant_mechanism": result.dominant_mechanism,
        "final_classification": result.final_classification,
        "supporting_mechanisms": result.supporting_mechanisms,
        "activation_levels": _export_fields(
            result.activation_levels,
            "splice_band", "rna_active", "rna_dominant",
            "context_band", "context_active", "protein_active",
        ),
        "layer_outputs": {
            "RNA": _export_fields(
                splice,
                "splice_prob", "splice_signal_strength",
                "counterfactual_delta", "model_available",
            ),
            "context": _export_fields(
                context,
                "context_pathogenic_prob", "activation_norm",
                "activation_peak_position", "importance_score",
                "model_available",
            ),
            "protein": _export_fields(
                protein,
                "biochemical_risk_score", "feature_pathogenic_prob",
                "shap_feature_contributions", "l3_substitution_valid",
                "model_available",
            ),
        },
        "af": _export_fields(
            af_result,
            "state", "global_af", "is_rare", "founder_variant_flag",
        ),
        "conflict_report": _export_fields(
            result.conflict_report,
            "major_conflicts", "minor_conflicts", "requires_manual_review",
            "conflict_score_major", "conflict_score_minor",
        ),
        "reasoning_steps": result.reasoning_steps,
        "transcript_ambiguity": result.transcript_ambiguity,
        "af_uncertainty": result.af_uncertainty,
        "prefilter_flags": vc.flags,
    }