import numpy as np import pandas as pd import gradio as gr import joblib from functools import lru_cache from huggingface_hub import hf_hub_download # ----------------------------- # Configuration # ----------------------------- MODEL_REPO = "Rolando666/RSF-post-ICU-discharge" MODEL_FILES = { (30, "age0_24"): "rsf_30d_age0_24.joblib", (30, "age25_64"): "rsf_30d_age25_64.joblib", (30, "age65plus"): "rsf_30d_age65plus.joblib", (60, "age0_24"): "rsf_60d_age0_24.joblib", (60, "age25_64"): "rsf_60d_age25_64.joblib", (60, "age65plus"): "rsf_60d_age65plus.joblib", } REGION_PREFIX = "region_" # ----------------------------- # Labels + slider ranges (Table 1) # ----------------------------- LABELS = { "charlindex": "Charlson comorbidity index", "score": "Overall frailty score", "diag_count": "Number of diagnoses at ICU", "hospital_beds": "Hospital beds", "time_in_icu": "Time in ICU (months)", "time_in_i": "Time in ICU (months)", "ami": "Acute myocardial infarction", "chf": "Congestive heart failure", "pvd": "Peripheral vascular disease", "stroke": "Cerebrovascular disease / stroke", "copd": "Chronic pulmonary disease (COPD)", "dementia": "Dementia", "diabetes": "Diabetes", "renal": "Renal disease", "cancer": "Cancer", "liver_mild": "Mild liver disease", "liver_severe": "Severe liver disease", } CONTINUOUS_RANGES = { "charlindex": (0, 15, 1), "score": (0.0, 36.8, 0.1), "diag_count": (1, 20, 1), "hospital_beds": (293, 2605, 1), "time_in_icu": (0.0, 10.0, 0.1), "time_in_i": (0.0, 10.0, 0.1), } # Mean defaults from Table 1 CONTINUOUS_MEANS = { "charlindex": 3.46, "score": 5.83, # overall frailty score "diag_count": 12.55, "hospital_beds": 1244, "time_in_icu": 0.42, # months "time_in_i": 0.42, # truncated variant } BINARY_HINTS = { "ami", "chf", "pvd", "stroke", "copd", "dementia", "diabetes", "renal", "cancer", "liver_mild", "liver_severe", } SEX_COL_CANDIDATES = {"sexcat_male", "sexcat_female", "sexcat_other"} ETHN_COL_CANDIDATES = {"ethn_white_white", "ethn_other", "ethn_white", "ethnicity_white"} # ----------------------------- # Utilities # ----------------------------- def _normalize(s: str) -> str: return s.strip().lower().replace("-", "_").replace(" ", "_") def human_label(feature_name: str) -> str: key = _normalize(feature_name) return LABELS.get(key, LABELS.get(feature_name, feature_name)) def age_group_from_age(age: float) -> str: if age < 25: return "age0_24" if age < 65: return "age25_64" return "age65plus" def is_region_feature(name: str) -> bool: return _normalize(name).startswith(REGION_PREFIX) def is_binary_feature(name: str) -> bool: n = _normalize(name) if n in BINARY_HINTS: return True if n.startswith(REGION_PREFIX): return True if n.endswith("_yes") or n.endswith("_no"): return True return False def is_continuous_slider_feature(name: str) -> bool: n = _normalize(name) if n in CONTINUOUS_RANGES: return True for k in CONTINUOUS_RANGES: if k in n: return True return False def _find_age_col(feature_names: list[str]) -> str | None: lowered = {c: _normalize(c) for c in feature_names} for c, cl in lowered.items(): if cl in {"age", "age_years", "ageyrs", "edad"}: return c for c, cl in lowered.items(): if "age" in cl: return c return None def _region_suffixes(feature_names: list[str]) -> list[str]: sufs = [] for f in feature_names: if is_region_feature(f): sufs.append(f[len(REGION_PREFIX):]) seen, out = set(), [] for s in sufs: if s not in seen: out.append(s) seen.add(s) return out def _pretty_token(x: str) -> str: x = x.replace("_", " ").strip() return x[:1].upper() + x[1:] if x else x @lru_cache(maxsize=16) def load_bundle(horizon: int, age_group: str) -> dict: fname = MODEL_FILES[(horizon, age_group)] path = hf_hub_download(repo_id=MODEL_REPO, filename=fname) return joblib.load(path) @lru_cache(maxsize=1) def base_feature_names() -> list[str]: b = load_bundle(30, "age25_64") return list(b["feature_names"]) def survival_at_horizon(estimator, X: pd.DataFrame, horizon_days: int) -> float: sfs = estimator.predict_survival_function(X) # list of StepFunction (usually) sf0 = sfs[0] if callable(sf0): v = sf0(float(horizon_days)) return float(np.asarray(v).reshape(-1)[0]) if hasattr(sf0, "x") and hasattr(sf0, "y"): times = np.asarray(sf0.x, dtype=float) vals = np.asarray(sf0.y, dtype=float) idx = int(np.searchsorted(times, float(horizon_days), side="right") - 1) if idx < 0: return 1.0 return float(vals[idx]) # last resort sf_arr = estimator.predict_survival_function(X, return_array=True) times = getattr(estimator, "unique_times_", getattr(estimator, "event_times_", None)) if times is None: raise AttributeError("Model has no time grid attribute and survival is not a StepFunction.") times = np.asarray(times, dtype=float) idx = int(np.argmin(np.abs(times - float(horizon_days)))) return float(sf_arr[0, idx]) def risk_category_from_survival(surv: float) -> str: """ Categorize risk using survival probability thresholds (user-defined). surv in [0,1]. """ s = float(surv) if s >= 0.90: return "very low risk" if 0.80 <= s < 0.90: return "low risk" if 0.70 <= s < 0.80: return "low-moderate risk" if 0.60 <= s < 0.70: return "moderate-low risk" if 0.50 <= s < 0.60: return "moderate risk" if 0.40 <= s < 0.60: return "moderate-high risk" if 0.20 <= s < 0.40: return "high risk" return "very high risk" def _set_demographics(full: dict, feature_names: list[str], sex: str, ethnicity: str): # SEX mapping sex = (sex or "").lower().strip() sex_cols_present = [c for c in feature_names if _normalize(c) in SEX_COL_CANDIDATES] for c in sex_cols_present: full[c] = 0 if sex_cols_present: if sex == "male": for c in sex_cols_present: if _normalize(c) == "sexcat_male": full[c] = 1 elif sex == "female": for c in sex_cols_present: if _normalize(c) == "sexcat_female": full[c] = 1 else: for c in sex_cols_present: if _normalize(c) == "sexcat_other": full[c] = 1 # ETHNICITY mapping (model uses white vs non-white; UI shows more categories) eth = (ethnicity or "").lower().strip() eth_cols_present = [c for c in feature_names if _normalize(c) in ETHN_COL_CANDIDATES] for c in eth_cols_present: full[c] = 0 if eth_cols_present: if eth == "white": for c in eth_cols_present: if _normalize(c) in {"ethn_white_white", "ethn_white", "ethnicity_white"}: full[c] = 1 break else: # Non-white categories map to 0 for the white indicator; if an "other" column exists, set it. for c in eth_cols_present: if _normalize(c) == "ethn_other": full[c] = 1 break # if no ethn_other exists, leaving all-zeros is consistent with white-as-reference coding # ----------------------------- # Build components # ----------------------------- FEATURES = base_feature_names() AGE_COL = _find_age_col(FEATURES) REGION_SUFFIXES = _region_suffixes(FEATURES) HIDDEN_DEMOGRAPHIC_COLS = SEX_COL_CANDIDATES.union(ETHN_COL_CANDIDATES) continuous_features, binary_features, other_numeric_features = [], [], [] for f in FEATURES: if f == AGE_COL: continue if is_region_feature(f): continue if _normalize(f) in HIDDEN_DEMOGRAPHIC_COLS: continue if is_continuous_slider_feature(f) and not is_binary_feature(f): continuous_features.append(f) elif is_binary_feature(f): binary_features.append(f) else: other_numeric_features.append(f) continuous_components: dict[str, gr.Slider] = {} binary_components: dict[str, gr.Checkbox] = {} other_numeric_components: dict[str, gr.Number] = {} for f in continuous_features: nf = _normalize(f) key = nf if nf in CONTINUOUS_RANGES else None if key is None: for k in CONTINUOUS_RANGES: if k in nf: key = k break mn, mx, st = CONTINUOUS_RANGES.get(key, (0.0, 100.0, 1.0)) # Default: mean if available, otherwise mid-point mean_val = CONTINUOUS_MEANS.get(key, (mn + mx) / 2) # Clip mean to range and snap to step mean_val = max(mn, min(mx, mean_val)) if st: mean_val = round((mean_val - mn) / st) * st + mn mean_val = max(mn, min(mx, mean_val)) continuous_components[f] = gr.Slider( minimum=mn, maximum=mx, step=st, value=mean_val, label=human_label(f), ) for f in other_numeric_features: other_numeric_components[f] = gr.Number(label=human_label(f), value=0) for f in binary_features: binary_components[f] = gr.Checkbox(label=human_label(f), value=False) # ----------------------------- # Visual rendering helpers # ----------------------------- def _risk_badge_html(category: str) -> str: # no custom colors required; use simple semantic styling cat = (category or "").strip().lower() label = category # Minimal CSS using grayscale but clear hierarchy return f"""