import numpy as np
import pandas as pd
import gradio as gr
import joblib
from functools import lru_cache
from huggingface_hub import hf_hub_download

# -----------------------------
# Configuration
# -----------------------------
MODEL_REPO = "Rolando666/RSF-post-ICU-discharge"

# One model artifact per (horizon in days, age-stratum key).
MODEL_FILES = {
    (30, "age0_24"): "rsf_30d_age0_24.joblib",
    (30, "age25_64"): "rsf_30d_age25_64.joblib",
    (30, "age65plus"): "rsf_30d_age65plus.joblib",
    (60, "age0_24"): "rsf_60d_age0_24.joblib",
    (60, "age25_64"): "rsf_60d_age25_64.joblib",
    (60, "age65plus"): "rsf_60d_age65plus.joblib",
}

# Model columns starting with this prefix are one-hot region indicators.
REGION_PREFIX = "region_"

# -----------------------------
# Labels + slider ranges (Table 1)
# -----------------------------
LABELS = {
    "charlindex": "Charlson comorbidity index",
    "score": "Overall frailty score",
    "diag_count": "Number of diagnoses at ICU",
    "hospital_beds": "Hospital beds",
    "time_in_icu": "Time in ICU (months)",
    "time_in_i": "Time in ICU (months)",
    "ami": "Acute myocardial infarction",
    "chf": "Congestive heart failure",
    "pvd": "Peripheral vascular disease",
    "stroke": "Cerebrovascular disease / stroke",
    "copd": "Chronic pulmonary disease (COPD)",
    "dementia": "Dementia",
    "diabetes": "Diabetes",
    "renal": "Renal disease",
    "cancer": "Cancer",
    "liver_mild": "Mild liver disease",
    "liver_severe": "Severe liver disease",
}

# feature key -> (slider minimum, slider maximum, slider step)
CONTINUOUS_RANGES = {
    "charlindex": (0, 15, 1),
    "score": (0.0, 36.8, 0.1),
    "diag_count": (1, 20, 1),
    "hospital_beds": (293, 2605, 1),
    "time_in_icu": (0.0, 10.0, 0.1),
    "time_in_i": (0.0, 10.0, 0.1),
}

# Mean defaults from Table 1
CONTINUOUS_MEANS = {
    "charlindex": 3.46,
    "score": 5.83,          # overall frailty score
    "diag_count": 12.55,
    "hospital_beds": 1244,
    "time_in_icu": 0.42,    # months
    "time_in_i": 0.42,      # truncated variant
}

# Normalized feature names rendered as checkboxes (0/1 indicators).
BINARY_HINTS = {
    "ami", "chf", "pvd", "stroke", "copd", "dementia", "diabetes",
    "renal", "cancer", "liver_mild", "liver_severe",
}

# Normalized one-hot demographic columns; these are filled from the Sex /
# Ethnicity radios rather than shown as individual inputs.
SEX_COL_CANDIDATES = {"sexcat_male", "sexcat_female", "sexcat_other"}
ETHN_COL_CANDIDATES = {"ethn_white_white", "ethn_other", "ethn_white", "ethnicity_white"}


# -----------------------------
# Utilities
# -----------------------------
def _normalize(s: str) -> str:
    """Strip and lower-case *s*, mapping '-' and ' ' to '_' for name matching."""
    return s.strip().lower().replace("-", "_").replace(" ", "_")


def human_label(feature_name: str) -> str:
    """Return the display label for a feature, falling back to the raw name."""
    key = _normalize(feature_name)
    return LABELS.get(key, LABELS.get(feature_name, feature_name))


def age_group_from_age(age: float) -> str:
    """Map an age in years to the stratum key used in MODEL_FILES."""
    if age < 25:
        return "age0_24"
    if age < 65:
        return "age25_64"
    return "age65plus"


def is_region_feature(name: str) -> bool:
    """True when the normalized column name starts with REGION_PREFIX."""
    return _normalize(name).startswith(REGION_PREFIX)


def is_binary_feature(name: str) -> bool:
    """True for known comorbidity flags, region dummies and *_yes / *_no columns."""
    n = _normalize(name)
    return (
        n in BINARY_HINTS
        or n.startswith(REGION_PREFIX)
        or n.endswith(("_yes", "_no"))
    )


def is_continuous_slider_feature(name: str) -> bool:
    """True when the name equals, or contains, a key of CONTINUOUS_RANGES."""
    n = _normalize(name)
    return n in CONTINUOUS_RANGES or any(k in n for k in CONTINUOUS_RANGES)


def _find_age_col(feature_names: list[str]) -> str | None:
    """Return the first column that looks like age, or None.

    Exact names are preferred; otherwise the first column whose normalized
    name merely contains "age" is returned.
    """
    lowered = {c: _normalize(c) for c in feature_names}
    for c, cl in lowered.items():
        if cl in {"age", "age_years", "ageyrs", "edad"}:
            return c
    # NOTE(review): this substring fallback also matches names such as
    # "stage", "dosage" or "average_*"; confirm no model column triggers it.
    for c, cl in lowered.items():
        if "age" in cl:
            return c
    return None


def _region_suffixes(feature_names: list[str]) -> list[str]:
    """Return unique region suffixes (text after REGION_PREFIX), first-seen order."""
    suffixes = (f[len(REGION_PREFIX):] for f in feature_names if is_region_feature(f))
    return list(dict.fromkeys(suffixes))


def _pretty_token(x: str) -> str:
    """Turn 'north_east' into 'North east' for display."""
    x = x.replace("_", " ").strip()
    return x[:1].upper() + x[1:] if x else x


@lru_cache(maxsize=16)
def load_bundle(horizon: int, age_group: str) -> dict:
    """Download (once per process) and load the artifact for a horizon/stratum.

    Raises KeyError when (horizon, age_group) is not in MODEL_FILES.
    """
    fname = MODEL_FILES[(horizon, age_group)]
    path = hf_hub_download(repo_id=MODEL_REPO, filename=fname)
    # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
    # only point MODEL_REPO at a repository you trust.
    return joblib.load(path)


@lru_cache(maxsize=1)
def base_feature_names() -> list[str]:
    """Feature names of the 30-day / 25-64 model, used to build the UI."""
    b = load_bundle(30, "age25_64")
    return list(b["feature_names"])


def _step_survival(times, values, t: float) -> float:
    """Evaluate a right-continuous step function at *t*.

    Returns the value at the last grid time <= t, or 1.0 when t precedes
    the first grid time. Assumes *times* is sorted ascending.
    """
    idx = int(np.searchsorted(times, t, side="right") - 1)
    if idx < 0:
        return 1.0
    return float(values[idx])


def survival_at_horizon(estimator, X: pd.DataFrame, horizon_days: int) -> float:
    """Return the survival probability of the first row of *X* at *horizon_days*.

    Three ways of evaluating the survival function are tried in order:
    calling it, reading its ``x``/``y`` grid, and finally the estimator's
    array output together with its time-grid attribute. The last two use
    the same step-function lookup.

    Raises AttributeError when no time grid can be found for the array
    fallback.
    """
    t = float(horizon_days)
    sfs = estimator.predict_survival_function(X)  # list of StepFunction (usually)
    sf0 = sfs[0]
    if callable(sf0):
        v = sf0(t)
        return float(np.asarray(v).reshape(-1)[0])
    if hasattr(sf0, "x") and hasattr(sf0, "y"):
        return _step_survival(
            np.asarray(sf0.x, dtype=float), np.asarray(sf0.y, dtype=float), t
        )

    # last resort
    sf_arr = estimator.predict_survival_function(X, return_array=True)
    times = getattr(estimator, "unique_times_", getattr(estimator, "event_times_", None))
    if times is None:
        raise AttributeError("Model has no time grid attribute and survival is not a StepFunction.")
    # Use the last grid time at or before the horizon (not the nearest time,
    # which could lie after the horizon), matching the x/y branch above.
    return _step_survival(np.asarray(times, dtype=float), np.asarray(sf_arr)[0], t)


# Lower survival bound (inclusive) for each category, highest band first.
_SURVIVAL_RISK_BANDS = (
    (0.90, "very low risk"),
    (0.80, "low risk"),
    (0.70, "low-moderate risk"),
    (0.60, "moderate-low risk"),
    (0.50, "moderate risk"),
    (0.40, "moderate-high risk"),
    (0.20, "high risk"),
)


def risk_category_from_survival(surv: float) -> str:
    """
    Categorize risk using survival probability thresholds (user-defined).
    surv in [0,1]. Values below 0.20 (and NaN) map to "very high risk".
    """
    s = float(surv)
    for lower, category in _SURVIVAL_RISK_BANDS:
        if s >= lower:
            return category
    return "very high risk"


def _set_demographics(full: dict, feature_names: list[str], sex: str, ethnicity: str):
    """Fill one-hot sex / ethnicity columns of *full* in place from the UI choices.

    Every present candidate column is reset to 0 first; then the column
    matching the selection (if the model has one) is set to 1.
    """
    # SEX mapping: anything other than "male"/"female" goes to sexcat_other.
    sex = (sex or "").lower().strip()
    sex_target = {"male": "sexcat_male", "female": "sexcat_female"}.get(sex, "sexcat_other")
    for c in feature_names:
        n = _normalize(c)
        if n in SEX_COL_CANDIDATES:
            full[c] = 1 if n == sex_target else 0

    # ETHNICITY mapping (model uses white vs non-white; UI shows more categories)
    eth = (ethnicity or "").lower().strip()
    eth_cols_present = [c for c in feature_names if _normalize(c) in ETHN_COL_CANDIDATES]
    for c in eth_cols_present:
        full[c] = 0
    # Non-white categories map to 0 for the white indicator; if an "other"
    # column exists, set it. If no ethn_other exists, leaving all-zeros is
    # consistent with white-as-reference coding. Only the first matching
    # column is set.
    if eth == "white":
        eth_targets = {"ethn_white_white", "ethn_white", "ethnicity_white"}
    else:
        eth_targets = {"ethn_other"}
    for c in eth_cols_present:
        if _normalize(c) in eth_targets:
            full[c] = 1
            break


# -----------------------------
# Build components
# -----------------------------
FEATURES = base_feature_names()
AGE_COL = _find_age_col(FEATURES)
REGION_SUFFIXES = _region_suffixes(FEATURES)
HIDDEN_DEMOGRAPHIC_COLS = SEX_COL_CANDIDATES.union(ETHN_COL_CANDIDATES)

continuous_features, binary_features, other_numeric_features = [], [], []
for f in FEATURES:
    # Age, region and demographics have dedicated widgets.
    if f == AGE_COL:
        continue
    if is_region_feature(f):
        continue
    if _normalize(f) in HIDDEN_DEMOGRAPHIC_COLS:
        continue
    if is_continuous_slider_feature(f) and not is_binary_feature(f):
        continuous_features.append(f)
    elif is_binary_feature(f):
        binary_features.append(f)
    else:
        other_numeric_features.append(f)

continuous_components: dict[str, gr.Slider] = {}
binary_components: dict[str, gr.Checkbox] = {}
other_numeric_components: dict[str, gr.Number] = {}

for f in continuous_features:
    nf = _normalize(f)
    # Exact key first, otherwise the first range key contained in the name.
    key = nf if nf in CONTINUOUS_RANGES else next((k for k in CONTINUOUS_RANGES if k in nf), None)
    mn, mx, st = CONTINUOUS_RANGES.get(key, (0.0, 100.0, 1.0))

    # Default: mean if available, otherwise mid-point
    mean_val = CONTINUOUS_MEANS.get(key, (mn + mx) / 2)

    # Clip mean to range and snap to step
    mean_val = max(mn, min(mx, mean_val))
    if st:
        mean_val = round((mean_val - mn) / st) * st + mn
        # Drop binary float noise from the snap (e.g. 5.800000000000001 -> 5.8).
        mean_val = round(mean_val, 10)
        mean_val = max(mn, min(mx, mean_val))

    continuous_components[f] = gr.Slider(
        minimum=mn,
        maximum=mx,
        step=st,
        value=mean_val,
        label=human_label(f),
    )

for f in other_numeric_features:
    other_numeric_components[f] = gr.Number(label=human_label(f), value=0)

for f in binary_features:
    binary_components[f] = gr.Checkbox(label=human_label(f), value=False)


# -----------------------------
# Visual rendering helpers
# -----------------------------
def _risk_badge_html(category: str) -> str:
    """Render the risk-category badge markup for the results panel."""
    # NOTE(review): earlier comments mentioned minimal grayscale CSS, but this
    # template contains only text; markup may have been lost - confirm.
    label = category
    return f"""
Risk category
{label}
"""


def _risk_gauge_html(risk: float) -> str:
    """
    risk in [0,1] (clipped). Intended to display a green→yellow→red bar with
    a simple vertical marker line.
    NOTE(review): the template below contains only the text parts of the
    gauge; confirm the bar markup is present in the deployed file.
    """
    r = float(max(0.0, min(1.0, risk)))
    pct = r * 100.0
    return f"""
Risk gauge (mortality risk = 1 − survival)
{pct:.1f}%
LowModerateHigh
"""


def age_group_label(age_group: str) -> str:
    """Human-readable label for a stratum key; unknown keys pass through."""
    mapping = {
        "age0_24": "0–24 years old",
        "age25_64": "25–64 years old",
        "age65plus": "65+ years old",
    }
    return mapping.get(age_group, age_group)


def _summary_html(surv: float, risk: float, horizon_days: int, age_group: str) -> str:
    """Render survival / mortality percentages and the stratum label."""
    surv_pct = f"{surv*100:.1f}%"
    risk_pct = f"{risk*100:.1f}%"
    return f"""
Survival probability at {horizon_days} days
{surv_pct}
Mortality risk at {horizon_days} days
{risk_pct}
Model stratum: {age_group}. Output is model-based risk stratification and is not a clinical diagnosis.
"""


# -----------------------------
# Prediction
# -----------------------------
def predict(horizon: int, age: float, region_choice: str, sex: str, ethnicity: str, *vals):
    """Gradio callback: build one covariate row, run the stratum model, render outputs.

    *vals* holds the widget values in the order continuous sliders,
    other numeric inputs, then checkboxes (matching ``ui_inputs``).

    Returns a 5-tuple (gauge HTML, summary HTML, badge HTML, JSON dict,
    error text). Any exception is reported in the error slot instead of
    being raised, with the other outputs emptied.
    """
    try:
        horizon = int(horizon)
        age = float(age)
        age_group = age_group_from_age(age)

        bundle = load_bundle(horizon, age_group)
        estimator = bundle["estimator"]
        feature_names = list(bundle["feature_names"])
        horizon_days = int(bundle.get("horizon_days", horizon))

        ui_feature_order = (
            list(continuous_components.keys())
            + list(other_numeric_components.keys())
            + list(binary_components.keys())
        )
        if len(vals) != len(ui_feature_order):
            raise ValueError("Internal UI mismatch: inputs do not match expected feature list.")

        x = dict(zip(ui_feature_order, vals))
        # NOTE(review): the UI is built from the 30d/25-64 model's columns; any
        # column that only this stratum's model has stays at 0 silently.
        full = {f: 0 for f in feature_names}

        # Age column if present
        model_age_col = _find_age_col(feature_names)
        if model_age_col is not None:
            full[model_age_col] = age

        # Region one-hot
        for f in feature_names:
            if is_region_feature(f):
                full[f] = 0

        if REGION_SUFFIXES:
            if not region_choice:
                raise ValueError("Region is required.")
            chosen = f"{REGION_PREFIX}{region_choice}"
            if chosen in full:
                full[chosen] = 1
            else:
                # Fall back to a normalized-name match (case / separator differences).
                lower_map = {_normalize(k): k for k in full.keys()}
                key_norm = _normalize(chosen)
                if key_norm in lower_map:
                    full[lower_map[key_norm]] = 1
                else:
                    raise ValueError(f"Selected region '{region_choice}' does not match model columns.")

        # Demographics mapping
        _set_demographics(full, feature_names, sex=sex, ethnicity=ethnicity)

        # Remaining covariates (UI values for columns this model does not use are ignored)
        for f, v in x.items():
            if f not in full:
                continue
            full[f] = int(v) if isinstance(v, (bool, np.bool_)) else v

        X = pd.DataFrame([{c: full.get(c, np.nan) for c in feature_names}])
        for c in X.columns:
            if X[c].dtype == bool:
                X[c] = X[c].astype(int)
            X[c] = pd.to_numeric(X[c], errors="coerce")

        if X.isna().any().any():
            na_cols = [c for c in X.columns if X[c].isna().any()]
            raise ValueError("Missing/invalid values for: " + ", ".join(na_cols))

        surv = survival_at_horizon(estimator, X, horizon_days=horizon_days)
        risk = 1.0 - surv
        gauge = _risk_gauge_html(risk)
        cat = risk_category_from_survival(surv)

        # Visual outputs
        badge = _risk_badge_html(cat)
        summary = _summary_html(surv, risk, horizon_days, age_group_label(age_group))

        json_out = {
            "horizon_days": horizon_days,
            "age_group": age_group,
            "age_group_label": age_group_label(age_group),
            "survival_probability": float(round(surv, 6)),
            "mortality_risk": float(round(risk, 6)),
            "risk_category": cat,
        }

        return (gauge, summary, badge, json_out, "")

    except Exception as e:
        # UI boundary: surface the error text instead of crashing the callback.
        return ("", "", "", {}, f"{type(e).__name__}: {str(e)}")


# -----------------------------
# UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        "# Random Survival Forests: Online Prediction Model of Post-ICU Risk\n"
        "\n\n\n"
        "This Online Prediction Model is based on a Post-ICU Risk Calculator that loads Random Survival Forest artifacts and returns survival probability and risk categories.\n\n"
        "The calculations are based on van Dongen DM, Gonzales Martinez RM, Nicodemo C, and Lasserson D: *Chronic Diseases and Multimorbidity on the Cardiorenal-Cerebrovascular-Frailty Axis Predict Early Post-ICU Mortality: Permutation-importance Insights from Age-stratified Random Survival Forests.*\n\n"
        "\n\n\n"
        " **Do not enter identifiable patient data.** "
    )

    with gr.Row():
        horizon = gr.Radio(
            [30, 60],
            value=30,
            label="Survival horizon (days) after ICU discharge",
        )
        age = gr.Slider(minimum=0, maximum=110, step=1, value=45, label="Age (years)")

        if REGION_SUFFIXES:
            # (display label, value) pairs; the value is the raw column suffix.
            region_choices = [(_pretty_token(s), s) for s in REGION_SUFFIXES]
            region_choice = gr.Dropdown(choices=region_choices, value=REGION_SUFFIXES[0], label="Region", interactive=True)
        else:
            region_choice = gr.Dropdown(choices=[], value=None, label="Region", interactive=False)

    with gr.Accordion("Demographics", open=True):
        sex = gr.Radio(["male", "female", "other"], value="male", label="Sex")
        ethnicity = gr.Radio(
            ["white", "asian", "black (african/caribbean)", "other"],
            value="white",
            label="Ethnicity",
        )

    with gr.Accordion("Continuous clinical covariates", open=True):
        for comp in continuous_components.values():
            comp.render()

    # Only show if it exists
    if len(other_numeric_components) > 0:
        with gr.Accordion("Additional numeric covariates", open=False):
            for comp in other_numeric_components.values():
                comp.render()

    with gr.Accordion("Comorbidities", open=True):
        for comp in binary_components.values():
            comp.render()

    predict_btn = gr.Button("Predict")
    gauge_html = gr.HTML()
    gr.Markdown("## Results")
    summary_html = gr.HTML()
    badge_html = gr.HTML()

    with gr.Accordion("Full output (JSON)", open=False):
        json_out = gr.JSON()

    err_out = gr.Textbox(label="Error details (if any)", interactive=False)

    # Order must match predict(horizon, age, region_choice, sex, ethnicity, *vals).
    ui_inputs = (
        [horizon, age, region_choice, sex, ethnicity]
        + list(continuous_components.values())
        + list(other_numeric_components.values())
        + list(binary_components.values())
    )

    predict_btn.click(
        fn=predict,
        inputs=ui_inputs,
        outputs=[gauge_html, summary_html, badge_html, json_out, err_out],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)