From 084a7c084a74ad5c05640b6f3dbde28f42b45c16 Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 5 Apr 2026 16:13:41 -0500 Subject: [PATCH 1/9] LOS as regression --- pyhealth/tasks/length_of_stay_prediction.py | 343 ++++++++++++++++++++ 1 file changed, 343 insertions(+) diff --git a/pyhealth/tasks/length_of_stay_prediction.py b/pyhealth/tasks/length_of_stay_prediction.py index 25e0c3121..9dbf3b24c 100644 --- a/pyhealth/tasks/length_of_stay_prediction.py +++ b/pyhealth/tasks/length_of_stay_prediction.py @@ -462,6 +462,349 @@ def __call__(self, patient: Patient) -> List[Dict]: return samples +class LengthOfStayRegressionMIMIC3(BaseTask): + """Predict ICU/hospital length of stay as a continuous value (days) using MIMIC-III. + + Unlike :class:`LengthOfStayPredictionMIMIC3` which bins LOS into 10 + categories, this task returns the raw float number of days so the model + can be trained and evaluated as a regression problem. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): conditions, procedures, drugs sequences. + output_schema (Dict[str, str]): ``{"los": "regression"}`` — a float + number of days. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import LengthOfStayRegressionMIMIC3 + >>> mimic3_base = MIMIC3Dataset( + ... root="/srv/local/data/physionet.org/files/mimiciii/1.4", + ... tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], + ... code_mapping={"ICD9CM": "CCSCM"}, + ... ) + >>> task = LengthOfStayRegressionMIMIC3() + >>> mimic3_sample = mimic3_base.set_task(task) + """ + + task_name: str = "LengthOfStayRegressionMIMIC3" + input_schema: Dict[str, str] = { + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + } + output_schema: Dict[str, str] = {"los": "regression"} + + def __call__(self, patient: Patient) -> List[Dict]: + samples = [] + + admissions = patient.get_events(event_type="admissions") + if len(admissions) == 0: + return [] + + for admission in admissions: + diagnoses_events = patient.get_events( + event_type="diagnoses_icd", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + conditions = [event.icd9_code for event in diagnoses_events] + + procedures_events = patient.get_events( + event_type="procedures_icd", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + procedures = [event.icd9_code for event in procedures_events] + + prescriptions_events = patient.get_events( + event_type="prescriptions", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + drugs = [event.ndc for event in prescriptions_events] + + if len(conditions) * len(procedures) * len(drugs) == 0: + continue + + admit_time = admission.timestamp + discharge_time = datetime.strptime(admission.dischtime, "%Y-%m-%d %H:%M:%S") + los_days = float((discharge_time - admit_time).days) + + samples.append( + { + "visit_id": admission.hadm_id, + "patient_id": patient.patient_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "los": los_days, + } + ) + return samples + + +class LengthOfStayRegressionMIMIC4(BaseTask): + """Predict ICU/hospital length of stay as a continuous value (days) using MIMIC-IV. + + Unlike :class:`LengthOfStayPredictionMIMIC4` which bins LOS into 10 + categories, this task returns the raw float number of days. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): conditions, procedures, drugs sequences. + output_schema (Dict[str, str]): ``{"los": "regression"}`` — a float + number of days. + + Examples: + >>> from pyhealth.datasets import MIMIC4Dataset + >>> from pyhealth.tasks import LengthOfStayRegressionMIMIC4 + >>> mimic4_base = MIMIC4Dataset( + ... root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", + ... tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + ... code_mapping={"ICD10PROC": "CCSPROC"}, + ... ) + >>> task = LengthOfStayRegressionMIMIC4() + >>> mimic4_sample = mimic4_base.set_task(task) + """ + + task_name: str = "LengthOfStayRegressionMIMIC4" + input_schema: Dict[str, str] = { + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + } + output_schema: Dict[str, str] = {"los": "regression"} + + def __call__(self, patient: Patient) -> List[Dict]: + samples = [] + + admissions = patient.get_events(event_type="admissions") + if len(admissions) == 0: + return [] + + for admission in admissions: + diagnoses_events = patient.get_events( + event_type="diagnoses_icd", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + conditions = [ + f"{event.icd_version}_{event.icd_code}" for event in diagnoses_events + ] + + procedures_events = patient.get_events( + event_type="procedures_icd", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + procedures = [ + f"{event.icd_version}_{event.icd_code}" for event in procedures_events + ] + + prescriptions_events = patient.get_events( + event_type="prescriptions", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + drugs = [event.ndc for event in prescriptions_events] + + if len(conditions) * len(procedures) * len(drugs) == 0: + continue + + admit_time = admission.timestamp + discharge_time = datetime.strptime(admission.dischtime, "%Y-%m-%d %H:%M:%S") + los_days = float((discharge_time - admit_time).days) + + samples.append( + { + "visit_id": admission.hadm_id, + "patient_id": patient.patient_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "los": los_days, + } + ) + return samples + + +class LengthOfStayRegressioneICU(BaseTask): + """Predict ICU length of stay as a continuous value (days) using eICU. + + Unlike :class:`LengthOfStayPredictioneICU` which bins LOS into 10 + categories, this task returns the raw float number of days. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): conditions, procedures, drugs sequences. + output_schema (Dict[str, str]): ``{"los": "regression"}`` — a float + number of days. + + Examples: + >>> from pyhealth.datasets import eICUDataset + >>> from pyhealth.tasks import LengthOfStayRegressioneICU + >>> dataset = eICUDataset( + ... root="/path/to/eicu-crd/2.0", + ... tables=["diagnosis", "medication", "physicalexam"], + ... ) + >>> task = LengthOfStayRegressioneICU() + >>> sample_dataset = dataset.set_task(task) + """ + + task_name: str = "LengthOfStayRegressioneICU" + input_schema: Dict[str, str] = { + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + } + output_schema: Dict[str, str] = {"los": "regression"} + + def __call__(self, patient: Patient) -> List[Dict]: + samples = [] + + patient_stays = patient.get_events(event_type="patient") + if len(patient_stays) == 0: + return [] + + for stay in patient_stays: + stay_id = str(getattr(stay, "patientunitstayid", "")) + if not stay_id: + continue + + diagnosis_events = patient.get_events( + event_type="diagnosis", + filters=[("patientunitstayid", "==", stay_id)], + ) + conditions = [ + getattr(event, "diagnosisstring", "") + for event in diagnosis_events + if getattr(event, "diagnosisstring", None) + ] + + physicalexam_events = patient.get_events( + event_type="physicalexam", + filters=[("patientunitstayid", "==", stay_id)], + ) + procedures = [ + getattr(event, "physicalexamvalue", "") + for event in physicalexam_events + if getattr(event, "physicalexamvalue", None) + ] + + medication_events = patient.get_events( + event_type="medication", + filters=[("patientunitstayid", "==", stay_id)], + ) + drugs = [ + getattr(event, "drugname", "") + for event in medication_events + if getattr(event, "drugname", None) + ] + + if len(conditions) * len(procedures) * len(drugs) == 0: + continue + + unit_discharge_offset = getattr(stay, "unitdischargeoffset", None) + if unit_discharge_offset is None: + continue + + try: + los_minutes = int(unit_discharge_offset) + except (ValueError, TypeError): + continue + + los_days = float(los_minutes) / (60.0 * 24.0) + + samples.append( + { + "visit_id": stay_id, + "patient_id": patient.patient_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "los": los_days, + } + ) + return samples + + +class LengthOfStayRegressionOMOP(BaseTask): + """Predict hospital length of stay as a continuous value (days) using OMOP CDM. + + Unlike :class:`LengthOfStayPredictionOMOP` which bins LOS into 10 + categories, this task returns the raw float number of days. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): conditions, procedures, drugs sequences. + output_schema (Dict[str, str]): ``{"los": "regression"}`` — a float + number of days. + + Examples: + >>> from pyhealth.datasets import OMOPDataset + >>> from pyhealth.tasks import LengthOfStayRegressionOMOP + >>> omop_base = OMOPDataset( + ... root="https://storage.googleapis.com/pyhealth/synpuf1k_omop_cdm_5.2.2", + ... tables=["condition_occurrence", "procedure_occurrence", "drug_exposure"], + ... code_mapping={}, + ... ) + >>> task = LengthOfStayRegressionOMOP() + >>> omop_sample = omop_base.set_task(task) + """ + + task_name: str = "LengthOfStayRegressionOMOP" + input_schema: Dict[str, str] = { + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + } + output_schema: Dict[str, str] = {"los": "regression"} + + def __call__(self, patient: Patient) -> List[Dict]: + samples = [] + + visit_occurrences = patient.get_events(event_type="visit_occurrence") + if len(visit_occurrences) == 0: + return [] + + for visit in visit_occurrences: + condition_events = patient.get_events( + event_type="condition_occurrence", + filters=[("visit_occurrence_id", "==", visit.visit_occurrence_id)], + ) + conditions = [event.condition_concept_id for event in condition_events] + + procedure_events = patient.get_events( + event_type="procedure_occurrence", + filters=[("visit_occurrence_id", "==", visit.visit_occurrence_id)], + ) + procedures = [event.procedure_concept_id for event in procedure_events] + + drug_events = patient.get_events( + event_type="drug_exposure", + filters=[("visit_occurrence_id", "==", visit.visit_occurrence_id)], + ) + drugs = [event.drug_concept_id for event in drug_events] + + if len(conditions) * len(procedures) * len(drugs) == 0: + continue + + admit_time = datetime.strptime( + visit.visit_start_datetime, "%Y-%m-%d %H:%M:%S" + ) + discharge_time = datetime.strptime( + visit.visit_end_datetime, "%Y-%m-%d %H:%M:%S" + ) + los_days = float((discharge_time - admit_time).days) + + samples.append( + { + "visit_id": visit.visit_occurrence_id, + "patient_id": patient.patient_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "los": los_days, + } + ) + return samples + + if __name__ == "__main__": from pyhealth.datasets import MIMIC3Dataset From 847fd2647172ff49d78710a6a0911dc193d7dbc4 Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 5 Apr 2026 16:15:11 -0500 Subject: [PATCH 2/9] CKD surv task --- pyhealth/tasks/ckd_surv.py | 756 +++++++++++++++++++++++++++++++++++++ 1 file changed, 756 insertions(+) create mode 100644 pyhealth/tasks/ckd_surv.py diff --git a/pyhealth/tasks/ckd_surv.py b/pyhealth/tasks/ckd_surv.py new file mode 100644 index 000000000..deaf7c985 --- /dev/null +++ b/pyhealth/tasks/ckd_surv.py @@ -0,0 +1,756 @@ +from typing import Any, Dict, List, Literal, Union, Type, Optional +from datetime import timedelta +import polars as pl + +from .base_task import BaseTask +from pyhealth.processors import ( + SequenceProcessor, + TensorProcessor, + RawProcessor, +) + + +class MIMIC4CKDSurvAnalysis(BaseTask): + """Survival analysis for CKD progression on MIMIC-IV (CKD -> ESRD). + + This task prepares patient-level samples for survival modeling using + MIMIC-IV tables (patients, admissions, diagnoses_icd, labevents). It + supports three settings that change the input form: + + - "time_invariant": single-row snapshot per patient + - "time_variant": time series with a single modality stream + - "heterogeneous": time series with multiple lab modalities + + The time origin (t0) is the first available lab in the window. + Positive cases are censored at the ESRD date (inclusive-by-date) and + negatives at the last available lab. Durations are computed in days + from t0. + + Inputs and outputs by setting + - Common output (all settings): + - duration_days: Tensor (float), days between t0 and censoring + - has_esrd: Tensor (int, 0/1), whether ESRD occurred in the window + + - time_invariant inputs: + - demographics: Sequence ([age_group, gender_str, race]) + - age: Tensor (float) + - gender: Sequence (["M"|"F"]) for modeling as categorical + - baseline_egfr: Tensor (float), from a single target lab + - comorbidities: Sequence (ICD codes prior to t0) + + - time_variant inputs: + - demographics: Sequence ([age_group, gender_str]) + - age: Tensor (float) + - gender: Sequence (["M"|"F"]) for modeling as categorical + - lab_measurements: Raw list[dict], ordered by days since t0 + Each element includes: + - timestamp: int, days since t0 + - creatinine: float (if present) + - egfr: float (if present) + - has_esrd_step: int (0/1), only when ESRD-day exists + - extras via extra_lab_itemids (e.g. bun) + - bun_missing flag (0 present, 1 missing) + + - heterogeneous inputs: + - demographics, age, gender: same as time_variant + - lab_measurements: Raw list[dict] with multimodal labs per day + Each element includes (when present): + - timestamp: int + - creatinine: float + - egfr: float, derived from creatinine, age, gender + - protein: float, albumin: float + - egfr_missing/protein_missing/albumin_missing: int (0/1) + - has_esrd_step: int (0/1) on the ESRD day + - Any configured extras plus {name}_missing flags + + Parameters + - setting: one of ["time_invariant", "time_variant", "heterogeneous"] + - min_age: minimum age (years) to include in cohort (default 18) + - prediction_window_days: not used to truncate currently; reserved + - extra_lab_itemids: optional dict mapping feature name -> list of + labevents.itemid strings to include as extra modalities. For each + name, + two fields may appear in lab_measurements: {name} (float) and + {name}_missing (int 0/1). Values are aligned to days since t0. + + Notes + - eGFR uses the CKD-EPI 2021 formula with base coefficient 142. See + https://pubmed.ncbi.nlm.nih.gov/34554658/ + - Positives require at least one lab event recorded on the ESRD date, + matching the original pipeline semantics. + + Example + ------- + >>> from pyhealth.datasets import MIMIC4Dataset + >>> from pyhealth.tasks.ckd_surv import MIMIC4CKDSurvAnalysis + >>> dataset = MIMIC4Dataset( + ... root="/path/to/mimiciv/demo", + ... tables=[ + ... "patients", "admissions", "labevents", "diagnoses_icd" + ... ], + ... dev=True, + ... ) + >>> task = MIMIC4CKDSurvAnalysis( + ... setting="time_variant", + ... extra_lab_itemids={"bun": ["51006"]}, + ... ) + >>> dataset.set_task(task) + >>> samples = dataset.samples + >>> sample = samples[0] + >>> sorted(sample.keys()) + ['age', 'demographics', 'duration_days', 'gender', 'has_esrd', + 'lab_measurements', 'patient_id'] + >>> sample['lab_measurements'][0].keys() + dict_keys(['timestamp', 'egfr_missing', 'protein_missing', + 'albumin_missing', 'egfr', 'protein', 'albumin', + 'creatinine', 'has_esrd_step', 'bun_missing', 'bun']) + + """ + + # Private class variables for settings + _SURVIVAL_SETTINGS = ["time_invariant", "time_variant", "heterogeneous"] + _CKD_CODES = ["N183", "N184", "N185", "585.3", "585.4", "585.5"] + _ESRD_CODES = ["N186", "Z992", "585.6", "V42.0"] + _CREATININE_ITEMIDS = ["50912", "52546"] + _PROTEIN_ITEMIDS = ["50976"] + _ALBUMIN_ITEMIDS = ["50862"] + + # Gender constants (using MIMIC-IV native string values) + _MALE_GENDER = "M" # Male patients + _FEMALE_GENDER = "F" # Female patients + + # CKD-EPI 2021 equation constants (from pkgs.data.utils.calculate_eGFR) + _BASE_COEFFICIENT = 142 # Match original pipeline constant + _AGE_FACTOR = 0.993 # Annual age decline factor + _FEMALE_ADJUSTMENT = 1.018 # Female gender boost factor + + # Gender-specific creatinine thresholds and exponents + _MALE_CREAT_THRESHOLD = 0.9 # mg/dL + _FEMALE_CREAT_THRESHOLD = 0.7 # mg/dL + _MALE_ALPHA_EXPONENT = -0.411 # For creatinine ≤ 0.9 + _FEMALE_ALPHA_EXPONENT = -0.329 # For creatinine ≤ 0.7 + _BETA_EXPONENT = -1.209 # For creatinine > threshold (both genders) + + def __init__( + self, + setting: Literal[ + "time_invariant", "time_variant", "heterogeneous" + ] = "time_invariant", + min_age: int = 18, + prediction_window_days: int = 365 * 5, + extra_lab_itemids: Optional[Dict[str, List[str]]] = None, + ): + + if setting not in self._SURVIVAL_SETTINGS: + raise ValueError(f"Setting must be one of {self._SURVIVAL_SETTINGS}") + + self.setting = setting + self.min_age = min_age + self.prediction_window_days = prediction_window_days + self.task_name = f"MIMIC4CKDSurvAnalysis_{self.setting}" + # Optional extensibility: additional lab item IDs to extract from + # labevents. Dict maps feature name -> list of itemids. Values will + # appear inside lab_measurements as the feature name and a + # corresponding "{feature}_missing" flag (0 present, 1 missing). + self.extra_lab_itemids = extra_lab_itemids or {} + self.input_schema, self.output_schema = self._configure_schemas() + + def _configure_schemas( + self, + ) -> tuple[Dict[str, Union[str, Type]], Dict[str, Union[str, Type]]]: + """Configure schemas based on survival setting. + + Use registered processors: + - "sequence" for categorical lists + (e.g., demographics, gender, comorbidities) + - "tensor" for numeric values (e.g., age, eGFR, durations, labels) + """ + base_input: Dict[str, Union[str, Type]] = { + "demographics": SequenceProcessor, + "age": TensorProcessor, + "gender": SequenceProcessor, + } + + base_output: Dict[str, Union[str, Type]] = { + "duration_days": TensorProcessor, + "has_esrd": TensorProcessor, + } + + if self.setting == "time_invariant": + base_input.update( + { + "baseline_egfr": TensorProcessor, + "comorbidities": SequenceProcessor, + } + ) + elif self.setting == "time_variant": + # Use raw processor for time series list of dicts + base_input.update({"lab_measurements": RawProcessor}) + else: # heterogeneous + # Raw lab measurements; per-timestep missing flags inside each + # measurement element + base_input.update({"lab_measurements": RawProcessor}) + + return base_input, base_output + + def filter_patients(self, df: pl.LazyFrame) -> pl.LazyFrame: + """Filter for CKD patients with required lab data.""" + ckd_patients = ( + df.filter(pl.col("event_type") == "diagnoses_icd") + .filter(pl.col("diagnoses_icd/icd_code").is_in(self._CKD_CODES)) + .select("patient_id") + .unique() + ) + + lab_patients = ( + df.filter(pl.col("event_type") == "labevents") + .filter(pl.col("labevents/itemid").is_in(self._CREATININE_ITEMIDS)) + .select("patient_id") + .unique() + ) + + valid_patients = ckd_patients.join(lab_patients, on="patient_id", how="inner") + return df.filter( + pl.col("patient_id").is_in(valid_patients.select("patient_id")) + ) + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process patient for survival analysis.""" + # Get demographics + demographics = patient.get_events(event_type="patients") + if not demographics: + return [] + + demo = demographics[0] + age = int(demo.anchor_age or 0) + gender = (demo.gender or "").upper() + + if gender not in [self._MALE_GENDER, self._FEMALE_GENDER]: + return [] # Skip patients with invalid/missing gender + + if age < self.min_age: + return [] + + # Gather diagnoses + ckd_diagnoses = patient.get_events(event_type="diagnoses_icd") + ckd_events = [e for e in ckd_diagnoses if e.icd_code in self._CKD_CODES] + if not ckd_events: + return [] + esrd_events = [e for e in ckd_diagnoses if e.icd_code in self._ESRD_CODES] + esrd_date = min((e.timestamp for e in esrd_events), default=None) + + # Collect lab events relevant to the scenario and validate + lab_events = patient.get_events(event_type="labevents") + + def _valid_numeric(e): + try: + return ( + e.valuenum is not None + and float(e.valuenum) > 0 + and e.timestamp is not None + ) + except (ValueError, TypeError): + return False + + # Pre-compute extra lab events (validated numeric) + extra_events_map: Dict[str, List[Any]] = {} + for feat_name, itemids in self.extra_lab_itemids.items(): + if not itemids: + continue + extra_events_map[feat_name] = [ + e for e in lab_events if e.itemid in itemids and _valid_numeric(e) + ] + + # Select labs per scenario + if self.setting in ("time_invariant", "time_variant"): + creatinine_events = [ + e + for e in lab_events + if e.itemid in self._CREATININE_ITEMIDS and _valid_numeric(e) + ] + if not creatinine_events: + return [] + # t0 is first creatinine lab + t0 = min(e.timestamp for e in creatinine_events) + # For positives, keep labs up to ESRD date (inclusive by date) + if esrd_date is not None: + # Require at least one lab on the ESRD date to match original + # pipeline + labs_on_esrd_date = [ + e + for e in creatinine_events + if e.timestamp.date() == esrd_date.date() + ] + if not labs_on_esrd_date: + return [] + considered_creatinine = [ + e + for e in creatinine_events + if e.timestamp.date() <= esrd_date.date() + ] + has_esrd = 1 + duration_days = (esrd_date.date() - t0.date()).days + else: + considered_creatinine = creatinine_events + has_esrd = 0 + last_lab_time = max(e.timestamp for e in considered_creatinine) + duration_days = (last_lab_time.date() - t0.date()).days + + # Need at least two labs in the window + if len(considered_creatinine) < 2 or duration_days <= 0: + return [] + + # Dispatch per setting + if self.setting == "time_invariant": + return self._process_time_invariant( + patient, + t0, + age, + gender, + duration_days, + has_esrd, + considered_creatinine, + esrd_date, + ) + else: + # Filter extras by ESRD date if positive + if has_esrd and esrd_date is not None: + filtered_extras: Dict[str, List[Any]] = {} + for name, events in extra_events_map.items(): + filtered_extras[name] = [ + e for e in events if e.timestamp.date() <= esrd_date.date() + ] + else: + filtered_extras = extra_events_map + + return self._process_time_variant( + patient, + t0, + age, + gender, + duration_days, + has_esrd, + considered_creatinine, + esrd_date, + filtered_extras, + ) + + else: # heterogeneous + # Consider creatinine, protein, albumin + creatinine_events = [ + e + for e in lab_events + if e.itemid in self._CREATININE_ITEMIDS and _valid_numeric(e) + ] + protein_events = [ + e + for e in lab_events + if e.itemid in self._PROTEIN_ITEMIDS and _valid_numeric(e) + ] + albumin_events = [ + e + for e in lab_events + if e.itemid in self._ALBUMIN_ITEMIDS and _valid_numeric(e) + ] + + # Need creatinine to derive egfr at minimum + if not creatinine_events: + return [] + + # t0 is min across all available labs for this scenario + timestamps = [ + e.timestamp + for e in ( + creatinine_events + + protein_events + + albumin_events + + [ev for lst in extra_events_map.values() for ev in lst] + ) + if e.timestamp is not None + ] + if not timestamps: + return [] + t0 = min(timestamps) + + if esrd_date is not None: + # Require at least one lab on ESRD date + any_on_esrd = any( + e.timestamp.date() == esrd_date.date() + for e in ( + creatinine_events + + protein_events + + albumin_events + + [ev for lst in extra_events_map.values() for ev in lst] + ) + ) + if not any_on_esrd: + return [] + considered_creatinine = [ + e + for e in creatinine_events + if e.timestamp.date() <= esrd_date.date() + ] + considered_protein = [ + e for e in protein_events if e.timestamp.date() <= esrd_date.date() + ] + considered_albumin = [ + e for e in albumin_events if e.timestamp.date() <= esrd_date.date() + ] + considered_extras = { + name: [e for e in events if e.timestamp.date() <= esrd_date.date()] + for name, events in extra_events_map.items() + } + has_esrd = 1 + duration_days = (esrd_date.date() - t0.date()).days + else: + considered_creatinine = creatinine_events + considered_protein = protein_events + considered_albumin = albumin_events + considered_extras = extra_events_map + has_esrd = 0 + last_time = max( + [ + e.timestamp + for e in ( + considered_creatinine + + considered_protein + + considered_albumin + + [ev for lst in considered_extras.values() for ev in lst] + ) + ] + ) + duration_days = (last_time.date() - t0.date()).days + + # Ensure at least two total timepoints across any lab + total_events = len( + { + e.timestamp + for e in ( + considered_creatinine + + considered_protein + + considered_albumin + + [ev for lst in considered_extras.values() for ev in lst] + ) + } + ) + if total_events < 2 or duration_days <= 0: + return [] + + return self._process_heterogeneous( + patient, + t0, + age, + gender, + duration_days, + has_esrd, + considered_creatinine, + considered_protein, + considered_albumin, + esrd_date, + considered_extras, + ) + + def _process_time_invariant( + self, + patient, + t0, + age, + gender, + duration_days, + has_esrd, + considered_creatinine, + esrd_date, + ): + """ + Process for time-invariant analysis aligned with original + NON_TIME_VARIANT. + + - Positives: pick lab on ESRD date (last that day) and compute egfr + - Negatives: pick last available lab + """ + # Choose target creatinine event + if has_esrd and esrd_date is not None: + same_day_events = [ + e + for e in considered_creatinine + if e.timestamp.date() == esrd_date.date() + ] + if not same_day_events: + return [] + target_event = max(same_day_events, key=lambda x: x.timestamp) + else: + target_event = max(considered_creatinine, key=lambda x: x.timestamp) + + try: + creatinine_value = float(target_event.valuenum) + except (ValueError, TypeError): + return [] + if creatinine_value <= 0: + return [] + + egfr = self._calculate_egfr(creatinine_value, age, gender) + + # Comorbidities before first lab (t0) + diagnoses = patient.get_events(event_type="diagnoses_icd") + comorbidities = [ + e.icd_code + for e in diagnoses + if e.timestamp is not None and e.timestamp <= t0 and e.icd_code + ] + + # Race from admissions (optional meta) + admissions = patient.get_events(event_type="admissions") + race = admissions[0].race if admissions else "unknown" + + age_group = "elderly" if age >= 65 else "adult" + gender_str = "male" if gender == self._MALE_GENDER else "female" + + sample = { + "patient_id": patient.patient_id, + "demographics": [age_group, gender_str, race], + "baseline_egfr": egfr, + "comorbidities": comorbidities, + "age": float(age), + "gender": [gender], + "duration_days": float(duration_days), + "has_esrd": has_esrd, + } + return [sample] + + def _process_time_variant( + self, + patient, + t0, + age, + gender, + duration_days, + has_esrd, + considered_creatinine, + esrd_date, + extra_events_map: Optional[Dict[str, List[Any]]] = None, + ): + """ + Process for time-varying analysis aligned with original + TIME_VARIANT. + + Build series from first lab (t0) up to ESRD date (if positive) or last + lab (negative). + """ + # Build union of timepoints across creatinine and extras + extra_events_map = extra_events_map or {} + measurements_by_time: Dict[int, Dict[str, Any]] = {} + + def _ensure_day(day: int): + if day not in measurements_by_time: + measurements_by_time[day] = {"timestamp": day} + for name in extra_events_map.keys(): + measurements_by_time[day][f"{name}_missing"] = 1 + measurements_by_time[day][name] = 0.0 + + # Creatinine and egfr + considered_creatinine.sort(key=lambda x: x.timestamp) + for e in considered_creatinine: + try: + creatinine_value = float(e.valuenum) + if creatinine_value <= 0: + continue + except (ValueError, TypeError): + continue + + days_from_t0 = (e.timestamp.date() - t0.date()).days + egfr_value = self._calculate_egfr(creatinine_value, age, gender) + _ensure_day(days_from_t0) + m = measurements_by_time[days_from_t0] + m["egfr"] = egfr_value + m["creatinine"] = creatinine_value + if esrd_date is not None: + m["has_esrd_step"] = int(e.timestamp.date() == esrd_date.date()) + # Extras + for name, events in extra_events_map.items(): + for e in events: + day = (e.timestamp.date() - t0.date()).days + try: + val = float(e.valuenum) + except (ValueError, TypeError): + continue + if val <= 0: + continue + _ensure_day(day) + measurements_by_time[day][name] = val + measurements_by_time[day][f"{name}_missing"] = 0 + + lab_measurements = [ + measurements_by_time[d] for d in sorted(measurements_by_time.keys()) + ] + + age_group = "elderly" if age >= 65 else "adult" + gender_str = "male" if gender == self._MALE_GENDER else "female" + + sample = { + "patient_id": patient.patient_id, + "demographics": [age_group, gender_str], + "lab_measurements": lab_measurements, + "age": float(age), + "gender": [gender], + "duration_days": float(duration_days), + "has_esrd": has_esrd, + } + return [sample] + + def _process_heterogeneous( + self, + patient, + t0, + age, + gender, + duration_days, + has_esrd, + creatinine_events, + protein_events, + albumin_events, + esrd_date, + extra_events_map: Optional[Dict[str, List[Any]]] = None, + ): + """Process for heterogeneous analysis with per-timestep missing flags. + + Missing flags use names: egfr_missing, protein_missing, albumin_missing + (0/1). + """ + measurements_by_time: Dict[int, Dict[str, Any]] = {} + extra_events_map = extra_events_map or {} + + def _upsert(days: int, updates: Dict[str, Any]): + if days not in measurements_by_time: + measurements_by_time[days] = { + "timestamp": days, + "egfr_missing": 1, + "protein_missing": 1, + "albumin_missing": 1, + "egfr": 0.0, + "protein": 0.0, + "albumin": 0.0, + "creatinine": 0.0, + } + measurements_by_time[days].update(updates) + + # Within-window events already considered by caller + for e in creatinine_events: + days = (e.timestamp.date() - t0.date()).days + try: + cr = float(e.valuenum) + except (ValueError, TypeError): + continue + if cr <= 0: + continue + egfr = self._calculate_egfr(cr, age, gender) + _upsert(days, {"egfr": egfr, "creatinine": cr, "egfr_missing": 0}) + + for e in protein_events: + days = (e.timestamp.date() - t0.date()).days + try: + pv = float(e.valuenum) + except (ValueError, TypeError): + continue + if pv <= 0: + continue + _upsert(days, {"protein": pv, "protein_missing": 0}) + + for e in albumin_events: + days = (e.timestamp.date() - t0.date()).days + try: + av = float(e.valuenum) + except (ValueError, TypeError): + continue + if av <= 0: + continue + _upsert(days, {"albumin": av, "albumin_missing": 0}) + + # Extras + for name, events in extra_events_map.items(): + for e in events: + days = (e.timestamp.date() - t0.date()).days + try: + val = float(e.valuenum) + except (ValueError, TypeError): + continue + if val <= 0: + continue + _upsert(days, {name: val, f"{name}_missing": 0}) + + if len(measurements_by_time) < 2: + return [] + + lab_measurements: List[Dict[str, Any]] = [] + for days in sorted(measurements_by_time.keys()): + m = measurements_by_time[days] + if esrd_date is not None: + # Set step-level ESRD flag when day matches ESRD date + m["has_esrd_step"] = int( + (t0.date() + timedelta(days=days)) == esrd_date.date() + ) + lab_measurements.append(m) + + age_group = "elderly" if age >= 65 else "adult" + gender_str = "male" if gender == self._MALE_GENDER else "female" + + sample = { + "patient_id": patient.patient_id, + "demographics": [age_group, gender_str], + "lab_measurements": lab_measurements, + "age": float(age), + "gender": [gender], + "duration_days": float(duration_days), + "has_esrd": has_esrd, + } + return [sample] + + def _calculate_egfr(self, creatinine: float, age: int, gender: str) -> float: + """Calculate eGFR using simplified CKD-EPI equation. + + Implementation adapted from original MIMIC-IV analysis code: + - Source file: pkgs.data.utils.calculate_eGFR() + - Formula: CKD-EPI 2021 (https://pubmed.ncbi.nlm.nih.gov/34554658/) + - Original coefficient: 142 (updated from 141 in this implementation) + + CKD-EPI Formula Constants (from original utils.py): + - 0.9/0.7: Gender-specific creatinine thresholds (mg/dL) + - 0.993: Age factor per year + - 1.018: Female gender adjustment factor + - -0.411/-0.329: Alpha exponents for creatinine ≤ threshold + - -1.209: Beta exponent for creatinine > threshold (both genders) + + Args: + creatinine: Serum creatinine in mg/dL (MIMIC-IV native units) + age: Patient age in years + gender: Gender string ('M' for male, 'F' for female) + + Returns: + Estimated GFR in mL/min/1.73m² + """ + # Validate inputs (following original validation) + if creatinine <= 0: + raise ValueError(f"Invalid creatinine value: {creatinine}") + if gender not in [self._MALE_GENDER, self._FEMALE_GENDER]: + raise ValueError(f"Invalid gender: {gender}") + + # Ensure creatinine is float for calculations + creatinine = float(creatinine) + + if gender == self._MALE_GENDER: # Male + return ( + self._BASE_COEFFICIENT + * min(creatinine / self._MALE_CREAT_THRESHOLD, 1) + ** self._MALE_ALPHA_EXPONENT + * max(creatinine / self._MALE_CREAT_THRESHOLD, 1) ** self._BETA_EXPONENT + * self._AGE_FACTOR**age + ) + else: # Female (gender == self._FEMALE_GENDER) + return ( + self._BASE_COEFFICIENT + * min(creatinine / self._FEMALE_CREAT_THRESHOLD, 1) + ** self._FEMALE_ALPHA_EXPONENT + * max(creatinine / self._FEMALE_CREAT_THRESHOLD, 1) + ** self._BETA_EXPONENT + * self._AGE_FACTOR**age + * self._FEMALE_ADJUSTMENT + ) From e490415a098319481c6793f46c3872ee7dd1fc3c Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 5 Apr 2026 16:19:45 -0500 Subject: [PATCH 3/9] survival metrics --- pyhealth/metrics/survival.py | 366 +++++++++++++++++++++++++++++++++++ 1 file changed, 366 insertions(+) create mode 100644 pyhealth/metrics/survival.py diff --git a/pyhealth/metrics/survival.py b/pyhealth/metrics/survival.py new file mode 100644 index 000000000..a58ee8113 --- /dev/null +++ b/pyhealth/metrics/survival.py @@ -0,0 +1,366 @@ +"""Survival analysis metrics for PyHealth. + +Implements Harrell's concordance index (C-index) and the inverse +probability of censoring weighted (IPCW) Brier score for evaluating +time-to-event / survival models. Both are computed in pure NumPy so +there is no extra dependency beyond what PyHealth already requires. + +Typical usage +------------- +>>> from pyhealth.metrics import survival_metrics_fn +>>> import numpy as np +>>> times = np.array([5, 10, 3, 8, 15]) +>>> events = np.array([1, 0, 1, 1, 0]) # 1 = event occurred +>>> scores = np.array([0.9, 0.3, 0.8, 0.7, 0.2]) # higher = higher risk +>>> survival_metrics_fn(times, scores, events, metrics=["c_index"]) +{'c_index': 0.9166...} +""" + +from typing import Dict, List, Optional, Tuple + +import numpy as np + +def _kaplan_meier( + times: np.ndarray, + events: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray]: + """Compute the Kaplan–Meier survival estimate. + + Args: + times: 1-D array of observed times. + events: 1-D binary array; 1 = event occurred, 0 = censored. + + Returns: + (unique_times, survival_probs): KM estimate at each unique event time, + with survival_probs[i] = P(T > unique_times[i]). + """ + order = np.argsort(times) + times_sorted = times[order] + events_sorted = events[order] + + n = len(times_sorted) + unique_times = [] + survival_probs = [] + s = 1.0 + + i = 0 + while i < n: + t = times_sorted[i] + # Collect all observations at this time + j = i + while j < n and times_sorted[j] == t: + j += 1 + n_at_risk = n - i + n_events = int(events_sorted[i:j].sum()) + if n_events > 0: + s *= (1.0 - n_events / n_at_risk) + unique_times.append(t) + survival_probs.append(s) + i = j + + return np.array(unique_times, dtype=float), np.array(survival_probs, dtype=float) + + +def _km_at_times( + km_times: np.ndarray, + km_probs: np.ndarray, + query_times: np.ndarray, +) -> np.ndarray: + """Evaluate the KM step function at arbitrary query times. + + Uses the last-value-carries-forward convention (left-continuous). + + Args: + km_times: Unique event times from :func:`_kaplan_meier`. + km_probs: KM survival probabilities at those times. + query_times: Times at which to evaluate. + + Returns: + KM survival probability at each query time. + """ + result = np.ones(len(query_times), dtype=float) + for k, t in enumerate(query_times): + idx = np.searchsorted(km_times, t, side="right") - 1 + if idx >= 0: + result[k] = km_probs[idx] + return result + +def concordance_index_censored( + event_times: np.ndarray, + predicted_scores: np.ndarray, + event_observed: np.ndarray, +) -> float: + """Harrell's concordance index for right-censored survival data. + + Counts all comparable pairs (i, j) where subject i had the event + before subject j (t_i < t_j and event_i = 1). A pair is concordant + when the model assigns a higher risk score to i. + + Complexity: O(n²) in memory for the boolean comparison matrices, which + is fine up to a few thousand samples. For very large n, sub-sample or + use an O(n log n) implementation. + + Args: + event_times: 1-D array of observed times (shape ``(n,)``). + predicted_scores: 1-D array of predicted risk scores — *higher means + higher risk / shorter expected survival* (shape ``(n,)``). + event_observed: 1-D binary array; 1 = event, 0 = censored (shape ``(n,)``). + + Returns: + C-index in [0, 1]. 0.5 means random; 1.0 means perfect ranking. + + Raises: + ValueError: If inputs have incompatible shapes or no comparable pairs + exist. + """ + event_times = np.asarray(event_times, dtype=float) + predicted_scores = np.asarray(predicted_scores, dtype=float) + event_observed = np.asarray(event_observed, dtype=bool) + + if not (event_times.shape == predicted_scores.shape == event_observed.shape): + raise ValueError("event_times, predicted_scores, and event_observed must have the same shape.") + if event_times.ndim != 1: + raise ValueError("Inputs must be 1-D arrays.") + + t = event_times + r = predicted_scores + e = event_observed + + # Broadcasting: t_i[:, None] vs t_j[None, :] + t_i = t[:, np.newaxis] # (n, 1) + t_j = t[np.newaxis, :] # (1, n) + r_i = r[:, np.newaxis] + r_j = r[np.newaxis, :] + e_i = e[:, np.newaxis] + + # Comparable: i had event first (strict), j not necessarily uncensored + comparable = (t_i < t_j) & e_i # (n, n) + concordant = comparable & (r_i > r_j) + tied_risk = comparable & (r_i == r_j) + + n_comparable = int(comparable.sum()) + if n_comparable == 0: + return 0.5 + + c_index = (float(concordant.sum()) + 0.5 * float(tied_risk.sum())) / n_comparable + return c_index + + +def brier_score_survival( + event_times: np.ndarray, + predicted_survival: np.ndarray, + event_observed: np.ndarray, + eval_time: float, +) -> float: + """IPCW Brier score for survival at a single evaluation time. + + Uses inverse probability of censoring weighting (IPCW) so that the + score is unbiased under informative censoring. The censoring + distribution G(t) is estimated non-parametrically via Kaplan–Meier on + the *censored* events. + + Reference: + Graf et al. (1999), "Assessment and comparison of prognostic + classification schemes for survival data". + + Args: + event_times: 1-D array of observed times. + predicted_survival: 1-D array of predicted P(T > eval_time | X), + i.e. survival probability at ``eval_time`` for each subject. + event_observed: 1-D binary array; 1 = event, 0 = censored. + eval_time: The time horizon at which to evaluate the score. + + Returns: + Brier score at ``eval_time``; 0 = perfect, 0.25 = random baseline + for a balanced dataset. + """ + event_times = np.asarray(event_times, dtype=float) + predicted_survival = np.asarray(predicted_survival, dtype=float) + event_observed = np.asarray(event_observed, dtype=bool) + + n = len(event_times) + + # Fit KM on the censoring distribution (flip event indicator) + censoring_observed = ~event_observed + km_times, km_probs = _kaplan_meier(event_times, censoring_observed.astype(float)) + + # G(t_i) for each subject and G(eval_time) + g_ti = _km_at_times(km_times, km_probs, event_times) + g_eval = _km_at_times(km_times, km_probs, np.array([eval_time]))[0] + + # Avoid division by zero + g_ti = np.maximum(g_ti, 1e-8) + g_eval = max(g_eval, 1e-8) + + s_hat = predicted_survival + + # IPCW Brier score terms + # Term 1: t_i <= t*, event_i = 1 → (0 - S_hat)² / G(t_i) + # Term 2: t_i > t* → (1 - S_hat)² / G(t*) + indicator_event_before = (event_times <= eval_time) & event_observed + indicator_alive = event_times > eval_time + + bs = ( + np.sum(s_hat**2 * indicator_event_before / g_ti) + + np.sum((1.0 - s_hat)**2 * indicator_alive / g_eval) + ) / n + + return float(bs) + + +def integrated_brier_score( + event_times: np.ndarray, + predicted_survival_fn, + event_observed: np.ndarray, + time_grid: Optional[np.ndarray] = None, + n_time_points: int = 100, +) -> float: + """Integrated Brier Score (IBS) over a time grid. + + Computes :func:`brier_score_survival` at each point in ``time_grid`` + and integrates via the trapezoid rule. + + Args: + event_times: 1-D observed times. + predicted_survival_fn: Callable ``(time_grid: np.ndarray) -> + np.ndarray`` of shape ``(n_subjects, len(time_grid))``, giving + predicted survival probabilities for each subject at each + evaluation time. + event_observed: 1-D binary event indicator. + time_grid: Times at which to evaluate the Brier score. Defaults + to ``n_time_points`` evenly-spaced points between the 10th + and 90th percentile of observed event times. + n_time_points: Number of points to use when ``time_grid`` is None. + + Returns: + Integrated Brier Score normalised by the time range, + i.e. ``IBS / (t_max - t_min)`` ∈ [0, 1]. + """ + event_times = np.asarray(event_times, dtype=float) + event_observed = np.asarray(event_observed, dtype=bool) + + if time_grid is None: + t_min = float(np.percentile(event_times, 10)) + t_max = float(np.percentile(event_times, 90)) + time_grid = np.linspace(t_min, t_max, n_time_points) + + # predicted_survival_fn returns (n_subjects, len(time_grid)) + survival_matrix = np.asarray(predicted_survival_fn(time_grid)) # (n, T) + + brier_scores = [] + for k, t in enumerate(time_grid): + s_at_t = survival_matrix[:, k] + bs = brier_score_survival(event_times, s_at_t, event_observed, t) + brier_scores.append(bs) + + brier_scores = np.array(brier_scores) + t_range = float(time_grid[-1] - time_grid[0]) + if t_range <= 0: + return float(np.mean(brier_scores)) + + ibs = float(np.trapz(brier_scores, time_grid)) / t_range + return ibs + + +def survival_metrics_fn( + event_times: np.ndarray, + predicted_scores: np.ndarray, + event_observed: np.ndarray, + metrics: Optional[List[str]] = None, + eval_time: Optional[float] = None, + predicted_survival: Optional[np.ndarray] = None, +) -> Dict[str, float]: + """Compute survival analysis evaluation metrics. + + This is the main entry point, analogous to + :func:`pyhealth.metrics.regression_metrics_fn` but for right-censored + survival data. + + Args: + event_times: 1-D array of observed times (n,). + predicted_scores: 1-D array of predicted risk scores — higher values + mean *higher risk* / shorter expected survival (n,). + event_observed: 1-D binary array; 1 = event occurred, 0 = censored (n,). + metrics: List of metric names to compute. Accepted values: + + - ``"c_index"``: Harrell's concordance index (default). + - ``"brier_score"``: IPCW Brier score at ``eval_time`` (requires + ``eval_time`` and ``predicted_survival``). + + Defaults to ``["c_index"]``. + eval_time: Time horizon used for ``"brier_score"``. Required when + ``"brier_score"`` is in ``metrics``. + predicted_survival: 1-D array of predicted P(T > eval_time | X) + for each subject. Required when ``"brier_score"`` is in + ``metrics``. + + Returns: + Dictionary mapping metric name → float value. + + Examples: + >>> import numpy as np + >>> from pyhealth.metrics import survival_metrics_fn + >>> times = np.array([5.0, 10.0, 3.0, 8.0, 15.0]) + >>> events = np.array([1, 0, 1, 1, 0 ]) + >>> scores = np.array([0.9, 0.3, 0.8, 0.7, 0.2]) + >>> survival_metrics_fn(times, scores, events) + {'c_index': ...} + """ + if metrics is None: + metrics = ["c_index"] + + event_times = np.asarray(event_times, dtype=float).flatten() + predicted_scores = np.asarray(predicted_scores, dtype=float).flatten() + event_observed = np.asarray(event_observed, dtype=float).flatten() + + output: Dict[str, float] = {} + + for metric in metrics: + if metric == "c_index": + output["c_index"] = concordance_index_censored( + event_times, predicted_scores, event_observed.astype(bool) + ) + elif metric == "brier_score": + if eval_time is None: + raise ValueError( + "'eval_time' must be provided when computing 'brier_score'." + ) + if predicted_survival is None: + raise ValueError( + "'predicted_survival' must be provided when computing 'brier_score'." + ) + ps = np.asarray(predicted_survival, dtype=float).flatten() + output["brier_score"] = brier_score_survival( + event_times, ps, event_observed.astype(bool), float(eval_time) + ) + else: + raise ValueError( + f"Unknown survival metric: '{metric}'. " + "Accepted values: 'c_index', 'brier_score'." + ) + + return output + + +if __name__ == "__main__": + rng = np.random.default_rng(42) + n = 200 + true_times = rng.exponential(scale=10, size=n) + censor_times = rng.exponential(scale=15, size=n) + observed_times = np.minimum(true_times, censor_times) + events = (true_times <= censor_times).astype(int) + risk_scores = 1.0 / true_times + rng.normal(0, 0.1, n) + + print(survival_metrics_fn(observed_times, risk_scores, events)) + eval_t = float(np.median(observed_times[events == 1])) + surv_at_t = np.clip(1.0 - risk_scores / risk_scores.max(), 0.01, 0.99) + print( + survival_metrics_fn( + observed_times, + risk_scores, + events, + metrics=["c_index", "brier_score"], + eval_time=eval_t, + predicted_survival=surv_at_t, + ) + ) From fa31f9f962ceaba2edd428c50c2b81fec9a598dc Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 5 Apr 2026 16:21:26 -0500 Subject: [PATCH 4/9] minor fix --- pyhealth/metrics/survival.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyhealth/metrics/survival.py b/pyhealth/metrics/survival.py index a58ee8113..74e6747c0 100644 --- a/pyhealth/metrics/survival.py +++ b/pyhealth/metrics/survival.py @@ -2,8 +2,7 @@ Implements Harrell's concordance index (C-index) and the inverse probability of censoring weighted (IPCW) Brier score for evaluating -time-to-event / survival models. Both are computed in pure NumPy so -there is no extra dependency beyond what PyHealth already requires. +time-to-event / survival models. Typical usage ------------- From a22662622122a8d432cf4b37374b713ba7e16421 Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 5 Apr 2026 16:23:21 -0500 Subject: [PATCH 5/9] metabric dataset --- pyhealth/datasets/metabric.py | 198 ++++++++++++++++++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 pyhealth/datasets/metabric.py diff --git a/pyhealth/datasets/metabric.py b/pyhealth/datasets/metabric.py new file mode 100644 index 000000000..e0a923d08 --- /dev/null +++ b/pyhealth/datasets/metabric.py @@ -0,0 +1,198 @@ +"""METABRIC dataset for PyHealth. + +METABRIC (Molecular Taxonomy of Breast Cancer International Consortium) is a +landmark breast cancer study combining clinical and genomic data for ~2,000 +patients with long-term follow-up. + +The dataset is publicly available from: + - cBioPortal: https://www.cbioportal.org/study/summary?id=brca_metabric + Download "All clinical data" (data_clinical_patient.txt) and optionally + "CNA data" or "mRNA expression" tables. + - Kaggle (pre-processed): search "METABRIC breast cancer clinical" +""" + +import logging +import os +from pathlib import Path +from typing import List, Optional + +import pandas as pd + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class METABRICDataset(BaseDataset): + """METABRIC breast cancer clinical dataset. + + Each patient has a single clinical record with demographics, treatment + indicators, tumour characteristics, and two survival endpoints: + + - **Overall Survival (OS)**: time ``OS_MONTHS`` and status ``OS_STATUS`` + (0 = living, 1 = died from cancer or unknown cause). + - **Relapse-Free Survival (RFS)**: time ``RFS_MONTHS`` and status + ``RFS_STATUS`` (0 = no relapse, 1 = relapse or death). + + Args: + root: Directory containing the processed ``metabric_clinical.csv`` + (or the raw cBioPortal ``data_clinical_patient.txt``). + tables: Additional tables to load beyond the default ``["metabric"]``. + dataset_name: Optional dataset name; defaults to ``"metabric"``. + config_path: Optional path to a YAML config; defaults to the bundled + ``configs/metabric.yaml``. + **kwargs: Passed through to :class:`~pyhealth.datasets.BaseDataset`. + + Examples: + >>> from pyhealth.datasets import METABRICDataset + >>> dataset = METABRICDataset(root="/path/to/metabric") + >>> dataset.stats() + >>> samples = dataset.set_task(task) + """ + + def __init__( + self, + root: str, + tables: Optional[List[str]] = None, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + if config_path is None: + config_path = Path(__file__).parent / "configs" / "metabric.yaml" + + processed_csv = os.path.join(root, "metabric_clinical.csv") + if not os.path.exists(processed_csv): + logger.info( + "metabric_clinical.csv not found — attempting to prepare from raw data." + ) + self.prepare_metadata(root) + + default_tables = ["metabric"] + tables = default_tables + (tables or []) + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "metabric", + config_path=str(config_path), + **kwargs, + ) + + @staticmethod + def prepare_metadata(root: str) -> None: + """Convert raw cBioPortal download to the processed CSV. + + Looks for ``data_clinical_patient.txt`` (tab-separated, with comment + header rows) inside ``root`` and writes + ``metabric_clinical.csv``. + + Args: + root: Directory to search for the raw file and write output. + """ + raw_candidates = [ + "data_clinical_patient.txt", + "METABRIC_RNA_Mutation.csv", + "metabric.csv", + ] + raw_file: Optional[str] = None + for fname in raw_candidates: + candidate = os.path.join(root, fname) + if os.path.exists(candidate): + raw_file = candidate + break + + output_path = os.path.join(root, "metabric_clinical.csv") + + if raw_file is None: + logger.warning( + f"No raw METABRIC file found in {root}. " + "Please download 'data_clinical_patient.txt' from " + "https://www.cbioportal.org/study/summary?id=brca_metabric " + "and place it in the root directory." + ) + # Write an empty placeholder so BaseDataset doesn't crash + pd.DataFrame( + columns=[ + "PATIENT_ID", + "AGE_AT_DIAGNOSIS", + "OS_MONTHS", + "OS_STATUS", + "RFS_MONTHS", + "RFS_STATUS", + "INFERRED_MENOPAUSAL_STATE", + "TUMOR_SIZE", + "TUMOR_STAGE", + "NPI", + "CELLULARITY", + "CHEMOTHERAPY", + "ER_IHC", + "HER2_SNP6", + "HORMONE_THERAPY", + "INTCLUST", + "ONCOTREE_CODE", + "RADIO_THERAPY", + "THREEGENE", + "GRADE", + "TYPE_OF_BREAST_SURGERY", + "PR_STATUS", + "HER2_STATUS", + ] + ).to_csv(output_path, index=False) + return + + logger.info(f"Processing METABRIC raw file: {raw_file}") + + if raw_file.endswith(".txt"): + # cBioPortal format: skip lines starting with '#', tab-separated + df = pd.read_csv(raw_file, sep="\t", comment="#", low_memory=False) + else: + df = pd.read_csv(raw_file, low_memory=False) + + # Normalise column names to upper-case and replace spaces/hyphens + df.columns = ( + df.columns.str.upper() + .str.strip() + .str.replace(" ", "_", regex=False) + .str.replace("-", "_", regex=False) + .str.replace("(", "", regex=False) + .str.replace(")", "", regex=False) + ) + + # Common alternative column name mappings + rename = { + "PATIENT_IDENTIFIER": "PATIENT_ID", + "OVERALL_SURVIVAL_STATUS": "OS_STATUS", + "OVERALL_SURVIVAL_MONTHS": "OS_MONTHS", + "RELAPSE_FREE_STATUS": "RFS_STATUS", + "RELAPSE_FREE_STATUS_MONTHS": "RFS_MONTHS", + "AGE_AT_INITIAL_PATHOLOGIC_DIAGNOSIS": "AGE_AT_DIAGNOSIS", + "INFERRED_MENOPAUSAL_STATE": "INFERRED_MENOPAUSAL_STATE", + "NOTTINGHAM_PROGNOSTIC_INDEX": "NPI", + "3_GENE_CLASSIFIER_SUBTYPE": "THREEGENE", + "TYPE_OF_BREAST_SURGERY": "TYPE_OF_BREAST_SURGERY", + } + df = df.rename(columns={k: v for k, v in rename.items() if k in df.columns}) + + # Parse OS_STATUS: cBioPortal stores "0:LIVING" / "1:DECEASED" + if "OS_STATUS" in df.columns: + df["OS_STATUS"] = ( + df["OS_STATUS"] + .astype(str) + .str.extract(r"^(\d+)", expand=False) + .astype(float) + ) + if "RFS_STATUS" in df.columns: + df["RFS_STATUS"] = ( + df["RFS_STATUS"] + .astype(str) + .str.extract(r"^(\d+)", expand=False) + .astype(float) + ) + + if "PATIENT_ID" not in df.columns: + df["PATIENT_ID"] = df.index.astype(str) + + df = df.drop_duplicates(subset=["PATIENT_ID"]) + df.to_csv(output_path, index=False) + logger.info(f"Saved {len(df)} METABRIC records to {output_path}") From c220f38b70509e38a58e27d32edfbe6cfe47ef2e Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 5 Apr 2026 16:24:57 -0500 Subject: [PATCH 6/9] metabric yaml --- pyhealth/datasets/configs/metabric.yaml | 29 +++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 pyhealth/datasets/configs/metabric.yaml diff --git a/pyhealth/datasets/configs/metabric.yaml b/pyhealth/datasets/configs/metabric.yaml new file mode 100644 index 000000000..6da373421 --- /dev/null +++ b/pyhealth/datasets/configs/metabric.yaml @@ -0,0 +1,29 @@ +version: "1.0" +tables: + metabric: + file_path: "metabric_clinical.csv" + patient_id: "PATIENT_ID" + timestamp: null + attributes: + - "AGE_AT_DIAGNOSIS" + - "OS_MONTHS" + - "OS_STATUS" + - "RFS_MONTHS" + - "RFS_STATUS" + - "INFERRED_MENOPAUSAL_STATE" + - "TUMOR_SIZE" + - "TUMOR_STAGE" + - "NPI" + - "CELLULARITY" + - "CHEMOTHERAPY" + - "ER_IHC" + - "HER2_SNP6" + - "HORMONE_THERAPY" + - "INTCLUST" + - "ONCOTREE_CODE" + - "RADIO_THERAPY" + - "THREEGENE" + - "GRADE" + - "TYPE_OF_BREAST_SURGERY" + - "PR_STATUS" + - "HER2_STATUS" From 9ae756e180a4ae032f079c3617b33c3a47136460 Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 5 Apr 2026 16:27:10 -0500 Subject: [PATCH 7/9] metabric survival os/rfs --- pyhealth/tasks/metabric_survival.py | 260 ++++++++++++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 pyhealth/tasks/metabric_survival.py diff --git a/pyhealth/tasks/metabric_survival.py b/pyhealth/tasks/metabric_survival.py new file mode 100644 index 000000000..ab8923379 --- /dev/null +++ b/pyhealth/tasks/metabric_survival.py @@ -0,0 +1,260 @@ +"""Survival analysis tasks for the METABRIC breast cancer dataset. + +Two ready-to-use task classes are provided: + +- :class:`METABRICSurvivalOS` — overall survival (OS_MONTHS / OS_STATUS). +- :class:`METABRICSurvivalRFS` — relapse-free survival (RFS_MONTHS / RFS_STATUS). + +Both share the same feature set and differ only in the survival endpoint. + +Usage +----- +>>> from pyhealth.datasets import METABRICDataset +>>> from pyhealth.tasks import METABRICSurvivalOS +>>> +>>> dataset = METABRICDataset(root="/path/to/metabric") +>>> task = METABRICSurvivalOS() +>>> samples = dataset.set_task(task) +>>> samples[0].keys() +dict_keys(['patient_id', 'clinical_features', 'treatment_features', + 'os_months', 'os_status']) +""" + +import logging +from typing import Any, Dict, List, Optional + +from .base_task import BaseTask + +logger = logging.getLogger(__name__) + +def _safe_float(value: Any) -> Optional[float]: + """Return float(value) or None if conversion fails.""" + if value is None: + return None + try: + return float(value) + except (ValueError, TypeError): + return None + + +def _safe_str(value: Any) -> Optional[str]: + """Return stripped string or None if empty.""" + if value is None: + return None + s = str(value).strip() + return s if s and s.lower() not in ("nan", "none", "") else None + + +def _extract_clinical_features(event: Any) -> List[str]: + """Build a list of descriptive clinical feature tokens. + + Each token encodes both the feature name and its value so that the + downstream :class:`~pyhealth.processors.SequenceProcessor` can learn + an embedding per (name, value) combination. + """ + features: List[str] = [] + + # Continuous → discretised bucket tokens + age = _safe_float(getattr(event, "AGE_AT_DIAGNOSIS", None)) + if age is not None: + bucket = "young" if age < 45 else ("middle" if age < 65 else "elderly") + features.append(f"age_group_{bucket}") + + tumor_size = _safe_float(getattr(event, "TUMOR_SIZE", None)) + if tumor_size is not None: + size_cat = "small" if tumor_size < 20 else ("medium" if tumor_size < 50 else "large") + features.append(f"tumor_size_{size_cat}") + + npi = _safe_float(getattr(event, "NPI", None)) + if npi is not None: + npi_cat = "good" if npi < 3.4 else ("moderate" if npi < 5.4 else "poor") + features.append(f"npi_{npi_cat}") + + grade = _safe_float(getattr(event, "GRADE", None)) + if grade is not None: + features.append(f"grade_{int(grade)}") + + tumor_stage = _safe_float(getattr(event, "TUMOR_STAGE", None)) + if tumor_stage is not None: + features.append(f"stage_{int(tumor_stage)}") + + # Categorical tokens (raw value) + for field in ( + "INFERRED_MENOPAUSAL_STATE", + "CELLULARITY", + "ER_IHC", + "HER2_SNP6", + "INTCLUST", + "ONCOTREE_CODE", + "THREEGENE", + "TYPE_OF_BREAST_SURGERY", + "PR_STATUS", + "HER2_STATUS", + ): + val = _safe_str(getattr(event, field, None)) + if val is not None: + features.append(f"{field.lower()}_{val.lower()}") + + return features + + +def _extract_treatment_features(event: Any) -> List[str]: + """Build treatment indicator tokens.""" + features: List[str] = [] + for field in ("CHEMOTHERAPY", "HORMONE_THERAPY", "RADIO_THERAPY"): + val = _safe_str(getattr(event, field, None)) + if val is not None: + features.append(f"{field.lower()}_{val.lower()}") + return features + +class METABRICSurvivalOS(BaseTask): + """Overall survival prediction task for the METABRIC dataset. + + Predicts overall survival (OS) for breast cancer patients: + + - **os_months**: continuous time-to-event / censoring time in months + (regression label). + - **os_status**: binary event indicator — 1 if the patient died, 0 if + alive at last follow-up (censored). + + Task Schema: + Input: + - clinical_features: sequence of tokenised clinical attributes + (age group, tumour size, grade, stage, ER/HER2 status, etc.) + - treatment_features: sequence of treatment indicator tokens + (chemotherapy, hormone therapy, radiotherapy). + Output: + - os_months: regression label (float, months). + - os_status: binary label (int 0/1). + + Examples: + >>> from pyhealth.datasets import METABRICDataset + >>> from pyhealth.tasks import METABRICSurvivalOS + >>> dataset = METABRICDataset(root="/path/to/metabric") + >>> samples = dataset.set_task(METABRICSurvivalOS()) + >>> samples[0]['os_months'] + 85.6 + >>> samples[0]['os_status'] + 0 + """ + + task_name: str = "METABRICSurvivalOS" + input_schema: Dict[str, str] = { + "clinical_features": "sequence", + "treatment_features": "sequence", + } + output_schema: Dict[str, str] = { + "os_months": "regression", + "os_status": "binary", + } + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + events = patient.get_events(event_type="metabric") + if not events: + return [] + + event = events[0] + + os_months = _safe_float(getattr(event, "OS_MONTHS", None)) + os_status_raw = _safe_float(getattr(event, "OS_STATUS", None)) + + if os_months is None or os_status_raw is None: + return [] + if os_months < 0: + return [] + + os_status = int(os_status_raw) + if os_status not in (0, 1): + return [] + + clinical_features = _extract_clinical_features(event) + treatment_features = _extract_treatment_features(event) + + # Require at least some clinical features + if not clinical_features: + return [] + + return [ + { + "patient_id": patient.patient_id, + "clinical_features": clinical_features, + "treatment_features": treatment_features, + "os_months": os_months, + "os_status": os_status, + } + ] + + +class METABRICSurvivalRFS(BaseTask): + """Relapse-free survival prediction task for the METABRIC dataset. + + Predicts relapse-free survival (RFS) for breast cancer patients: + + - **rfs_months**: continuous time-to-relapse / censoring time in months + (regression label). + - **rfs_status**: binary event indicator — 1 if relapse or death occurred, + 0 if relapse-free at last follow-up (censored). + + Task Schema: + Input: + - clinical_features: sequence of tokenised clinical attributes. + - treatment_features: sequence of treatment indicator tokens. + Output: + - rfs_months: regression label (float, months). + - rfs_status: binary label (int 0/1). + + Examples: + >>> from pyhealth.datasets import METABRICDataset + >>> from pyhealth.tasks import METABRICSurvivalRFS + >>> dataset = METABRICDataset(root="/path/to/metabric") + >>> samples = dataset.set_task(METABRICSurvivalRFS()) + >>> samples[0]['rfs_months'] + 62.3 + >>> samples[0]['rfs_status'] + 1 + """ + + task_name: str = "METABRICSurvivalRFS" + input_schema: Dict[str, str] = { + "clinical_features": "sequence", + "treatment_features": "sequence", + } + output_schema: Dict[str, str] = { + "rfs_months": "regression", + "rfs_status": "binary", + } + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + events = patient.get_events(event_type="metabric") + if not events: + return [] + + event = events[0] + + rfs_months = _safe_float(getattr(event, "RFS_MONTHS", None)) + rfs_status_raw = _safe_float(getattr(event, "RFS_STATUS", None)) + + if rfs_months is None or rfs_status_raw is None: + return [] + if rfs_months < 0: + return [] + + rfs_status = int(rfs_status_raw) + if rfs_status not in (0, 1): + return [] + + clinical_features = _extract_clinical_features(event) + treatment_features = _extract_treatment_features(event) + + if not clinical_features: + return [] + + return [ + { + "patient_id": patient.patient_id, + "clinical_features": clinical_features, + "treatment_features": treatment_features, + "rfs_months": rfs_months, + "rfs_status": rfs_status, + } + ] From dcb58af708cdd3fe3213049ff1d6e6ba3bd52c9b Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 5 Apr 2026 16:28:05 -0500 Subject: [PATCH 8/9] seer dataset + task --- pyhealth/datasets/__init__.py | 2 + pyhealth/datasets/configs/seer.yaml | 25 +++ pyhealth/datasets/seer.py | 266 ++++++++++++++++++++++++++++ pyhealth/metrics/__init__.py | 10 ++ pyhealth/tasks/__init__.py | 7 + pyhealth/tasks/seer_survival.py | 236 ++++++++++++++++++++++++ 6 files changed, 546 insertions(+) create mode 100644 pyhealth/datasets/configs/seer.yaml create mode 100644 pyhealth/datasets/seer.py create mode 100644 pyhealth/tasks/seer_survival.py diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..6e976286d 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -65,6 +65,8 @@ def __init__(self, *args, **kwargs): from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset +from .metabric import METABRICDataset +from .seer import SEERDataset from .support2 import Support2Dataset from .tcga_prad import TCGAPRADDataset from .splitter import ( diff --git a/pyhealth/datasets/configs/seer.yaml b/pyhealth/datasets/configs/seer.yaml new file mode 100644 index 000000000..2a25525ae --- /dev/null +++ b/pyhealth/datasets/configs/seer.yaml @@ -0,0 +1,25 @@ +version: "1.0" +tables: + seer: + file_path: "seer_clinical.csv" + patient_id: "PATIENT_ID" + timestamp: null + attributes: + - "AGE_AT_DIAGNOSIS" + - "SEX" + - "RACE" + - "PRIMARY_SITE" + - "HISTOLOGY" + - "STAGE" + - "GRADE" + - "TUMOR_SIZE_MM" + - "REGIONAL_NODES_EXAMINED" + - "REGIONAL_NODES_POSITIVE" + - "SURVIVAL_MONTHS" + - "VITAL_STATUS" + - "YEAR_OF_DIAGNOSIS" + - "SEQUENCE_NUMBER" + - "LATERALITY" + - "SURGERY" + - "RADIATION" + - "CHEMOTHERAPY" diff --git a/pyhealth/datasets/seer.py b/pyhealth/datasets/seer.py new file mode 100644 index 000000000..cd2250fd1 --- /dev/null +++ b/pyhealth/datasets/seer.py @@ -0,0 +1,266 @@ +"""SEER (Surveillance, Epidemiology, and End Results) dataset for PyHealth. + +The SEER program of the National Cancer Institute (NCI) collects cancer +incidence and survival data from population-based cancer registries covering +approximately 48% of the US population. + +Data access +----------- +SEER data requires a free research data agreement: + https://seer.cancer.gov/data/access.html + +Once approved, download a cohort as a CSV export from SEER*Stat. + +The pre-processed file ``seer_clinical.csv`` expected by this class should +have the columns listed in ``configs/seer.yaml``. +:func:`SEERDataset.prepare_metadata` converts a standard SEER*Stat CSV export +to this format when the processed file is absent. + +Common SEER*Stat export columns (variable labels vary by release; see the +SEER*Stat dictionary for your download): + - ``Patient ID`` + - ``Age recode with single ages and 85+`` (or ``Age at Diagnosis``) + - ``Sex`` + - ``Race recode (W, B, AI, API)`` + - ``Primary Site`` + - ``Histologic Type ICD-O-3`` + - ``Derived AJCC Stage Group, 7th ed (2010-2015)`` + - ``Grade`` + - ``CS tumor size (2004-2015)`` + - ``Regional nodes examined (1988+)`` + - ``Regional nodes positive (1988+)`` + - ``Survival months`` + - ``Vital status recode (study cutoff used)`` (Alive / Dead) + - ``Year of diagnosis`` + +Citation: + National Cancer Institute, DCCPS, Surveillance Research Program, + SEER*Stat software (www.seer.cancer.gov/seerstat). +""" + +import logging +import os +from pathlib import Path +from typing import List, Optional + +import pandas as pd + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class SEERDataset(BaseDataset): + """SEER cancer incidence / survival dataset. + + Each row represents a single tumour record. Patients with multiple + primaries will have multiple rows (distinguished by ``SEQUENCE_NUMBER``). + The dataset is loaded with ``PATIENT_ID`` as the patient identifier. + + Args: + root: Directory containing ``seer_clinical.csv`` (or a raw SEER*Stat + export named ``seer_raw.csv`` / ``seer.csv``). + tables: Additional tables beyond the default ``["seer"]``. + dataset_name: Optional name; defaults to ``"seer"``. + config_path: Optional YAML path; defaults to the bundled + ``configs/seer.yaml``. + **kwargs: Passed through to :class:`~pyhealth.datasets.BaseDataset`. + + Examples: + >>> from pyhealth.datasets import SEERDataset + >>> dataset = SEERDataset(root="/path/to/seer") + >>> dataset.stats() + >>> samples = dataset.set_task(task) + """ + + def __init__( + self, + root: str, + tables: Optional[List[str]] = None, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + if config_path is None: + config_path = Path(__file__).parent / "configs" / "seer.yaml" + + processed_csv = os.path.join(root, "seer_clinical.csv") + if not os.path.exists(processed_csv): + logger.info( + "seer_clinical.csv not found — attempting to prepare from raw data." + ) + self.prepare_metadata(root) + + default_tables = ["seer"] + tables = default_tables + (tables or []) + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "seer", + config_path=str(config_path), + **kwargs, + ) + + # ------------------------------------------------------------------ + # Data preparation + # ------------------------------------------------------------------ + + @staticmethod + def prepare_metadata(root: str) -> None: + """Convert a raw SEER*Stat CSV export to the standardised format. + + Looks for ``seer_raw.csv``, ``seer.csv``, or ``*.csv`` (first match) + inside ``root`` and writes ``seer_clinical.csv``. + + The mapping below covers the most common SEER*Stat export variable + names. If your export uses different labels, rename the columns in + the output CSV or pass a custom ``config_path`` to :class:`SEERDataset`. + + Args: + root: Directory to search and write output. + """ + raw_candidates = ["seer_raw.csv", "seer.csv"] + raw_file: Optional[str] = None + for fname in raw_candidates: + candidate = os.path.join(root, fname) + if os.path.exists(candidate): + raw_file = candidate + break + + # Fall back: any CSV in root + if raw_file is None: + for fname in os.listdir(root): + if fname.endswith(".csv") and fname != "seer_clinical.csv": + raw_file = os.path.join(root, fname) + logger.info(f"Using fallback raw file: {raw_file}") + break + + output_path = os.path.join(root, "seer_clinical.csv") + + if raw_file is None: + logger.warning( + f"No raw SEER file found in {root}. " + "Please export a cohort from SEER*Stat as a CSV and save it " + "as 'seer_raw.csv' in the root directory. " + "See https://seer.cancer.gov/seerstat for instructions." + ) + pd.DataFrame( + columns=[ + "PATIENT_ID", + "AGE_AT_DIAGNOSIS", + "SEX", + "RACE", + "PRIMARY_SITE", + "HISTOLOGY", + "STAGE", + "GRADE", + "TUMOR_SIZE_MM", + "REGIONAL_NODES_EXAMINED", + "REGIONAL_NODES_POSITIVE", + "SURVIVAL_MONTHS", + "VITAL_STATUS", + "YEAR_OF_DIAGNOSIS", + "SEQUENCE_NUMBER", + "LATERALITY", + "SURGERY", + "RADIATION", + "CHEMOTHERAPY", + ] + ).to_csv(output_path, index=False) + return + + logger.info(f"Processing SEER raw file: {raw_file}") + df = pd.read_csv(raw_file, low_memory=False) + + # --- column name normalisation --- + # SEER*Stat uses verbose labels; map common variants to short names. + rename: dict = {} + for col in df.columns: + col_upper = col.upper().strip() + if "PATIENT ID" in col_upper or col_upper == "PATIENT_ID": + rename[col] = "PATIENT_ID" + elif "AGE" in col_upper and "DIAGNOSIS" in col_upper: + rename[col] = "AGE_AT_DIAGNOSIS" + elif col_upper in ("SEX", "GENDER"): + rename[col] = "SEX" + elif "RACE" in col_upper and "RECODE" not in col_upper: + rename[col] = "RACE" + elif "RACE RECODE" in col_upper: + rename[col] = "RACE" + elif "PRIMARY SITE" in col_upper or "PRIMARY_SITE" in col_upper: + rename[col] = "PRIMARY_SITE" + elif "HISTOLOGIC TYPE" in col_upper or "HISTOLOGY" in col_upper: + rename[col] = "HISTOLOGY" + elif "STAGE" in col_upper and "AJCC" not in col_upper: + rename[col] = "STAGE" + elif "DERIVED AJCC STAGE" in col_upper: + rename[col] = "STAGE" + elif col_upper == "GRADE": + rename[col] = "GRADE" + elif "CS TUMOR SIZE" in col_upper or "TUMOR SIZE" in col_upper: + rename[col] = "TUMOR_SIZE_MM" + elif "REGIONAL NODES EXAMINED" in col_upper: + rename[col] = "REGIONAL_NODES_EXAMINED" + elif "REGIONAL NODES POSITIVE" in col_upper: + rename[col] = "REGIONAL_NODES_POSITIVE" + elif "SURVIVAL MONTHS" in col_upper: + rename[col] = "SURVIVAL_MONTHS" + elif "VITAL STATUS" in col_upper: + rename[col] = "VITAL_STATUS" + elif "YEAR OF DIAGNOSIS" in col_upper: + rename[col] = "YEAR_OF_DIAGNOSIS" + elif "SEQUENCE NUMBER" in col_upper: + rename[col] = "SEQUENCE_NUMBER" + elif "LATERALITY" in col_upper: + rename[col] = "LATERALITY" + elif "SURGERY" in col_upper: + rename[col] = "SURGERY" + elif "RADIATION" in col_upper: + rename[col] = "RADIATION" + elif "CHEMO" in col_upper: + rename[col] = "CHEMOTHERAPY" + + df = df.rename(columns=rename) + + # SEER vital status: "Alive" → 0, "Dead" → 1 + if "VITAL_STATUS" in df.columns: + df["VITAL_STATUS"] = ( + df["VITAL_STATUS"] + .astype(str) + .str.strip() + .str.upper() + .map({"ALIVE": 0, "DEAD": 1, "0": 0, "1": 1}) + .fillna(df["VITAL_STATUS"]) + ) + + if "PATIENT_ID" not in df.columns: + df["PATIENT_ID"] = df.index.astype(str) + + # Select output columns (use those present) + desired_cols = [ + "PATIENT_ID", + "AGE_AT_DIAGNOSIS", + "SEX", + "RACE", + "PRIMARY_SITE", + "HISTOLOGY", + "STAGE", + "GRADE", + "TUMOR_SIZE_MM", + "REGIONAL_NODES_EXAMINED", + "REGIONAL_NODES_POSITIVE", + "SURVIVAL_MONTHS", + "VITAL_STATUS", + "YEAR_OF_DIAGNOSIS", + "SEQUENCE_NUMBER", + "LATERALITY", + "SURGERY", + "RADIATION", + "CHEMOTHERAPY", + ] + present_cols = [c for c in desired_cols if c in df.columns] + df_out = df[present_cols] + + df_out.to_csv(output_path, index=False) + logger.info(f"Saved {len(df_out)} SEER records to {output_path}") diff --git a/pyhealth/metrics/__init__.py b/pyhealth/metrics/__init__.py index da8da0f5b..11c0ed040 100644 --- a/pyhealth/metrics/__init__.py +++ b/pyhealth/metrics/__init__.py @@ -13,6 +13,12 @@ # from .fairness import fairness_metrics_fn from .ranking import ranking_metrics_fn from .regression import regression_metrics_fn +from .survival import ( + concordance_index_censored, + brier_score_survival, + integrated_brier_score, + survival_metrics_fn, +) __all__ = [ "binary_metrics_fn", @@ -26,4 +32,8 @@ "multilabel_metrics_fn", "ranking_metrics_fn", "regression_metrics_fn", + "concordance_index_censored", + "brier_score_survival", + "integrated_brier_score", + "survival_metrics_fn", ] diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..781e4a8f0 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -22,11 +22,16 @@ drug_recommendation_omop_fn, ) from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 +from .ckd_surv import MIMIC4CKDSurvAnalysis from .length_of_stay_prediction import ( LengthOfStayPredictioneICU, LengthOfStayPredictionMIMIC3, LengthOfStayPredictionMIMIC4, LengthOfStayPredictionOMOP, + LengthOfStayRegressioneICU, + LengthOfStayRegressionMIMIC3, + LengthOfStayRegressionMIMIC4, + LengthOfStayRegressionOMOP, ) from .length_of_stay_stagenet_mimic4 import LengthOfStayStageNetMIMIC4 from .medical_coding import MIMIC3ICD9Coding @@ -40,6 +45,8 @@ MultimodalMortalityPredictionMIMIC3, MultimodalMortalityPredictionMIMIC4, ) +from .metabric_survival import METABRICSurvivalOS, METABRICSurvivalRFS +from .seer_survival import SEERSurvivalTask from .survival_preprocess_support2 import SurvivalPreprocessSupport2 from .mortality_prediction_stagenet_mimic4 import ( MortalityPredictionStageNetMIMIC4, diff --git a/pyhealth/tasks/seer_survival.py b/pyhealth/tasks/seer_survival.py new file mode 100644 index 000000000..5103805c2 --- /dev/null +++ b/pyhealth/tasks/seer_survival.py @@ -0,0 +1,236 @@ +"""Survival analysis task for the SEER cancer registry dataset. + +The task predicts overall survival (time in months + event indicator) for +cancer patients in the SEER dataset. + +Usage +----- +>>> from pyhealth.datasets import SEERDataset +>>> from pyhealth.tasks import SEERSurvivalTask +>>> +>>> dataset = SEERDataset(root="/path/to/seer") +>>> task = SEERSurvivalTask() +>>> samples = dataset.set_task(task) +>>> samples[0].keys() +dict_keys(['patient_id', 'patient_features', 'tumour_features', + 'treatment_features', 'survival_months', 'vital_status']) +""" + +import logging +from typing import Any, Dict, List, Optional + +from .base_task import BaseTask + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _safe_float(value: Any) -> Optional[float]: + if value is None: + return None + try: + return float(value) + except (ValueError, TypeError): + return None + + +def _safe_str(value: Any) -> Optional[str]: + if value is None: + return None + s = str(value).strip() + return s if s and s.lower() not in ("nan", "none", "unknown", "") else None + + +def _age_bucket(age: float) -> str: + if age < 30: + return "under30" + elif age < 45: + return "30_44" + elif age < 60: + return "45_59" + elif age < 75: + return "60_74" + else: + return "75plus" + + +def _extract_patient_features(event: Any) -> List[str]: + """Demographic and patient-level tokens.""" + features: List[str] = [] + + age = _safe_float(getattr(event, "AGE_AT_DIAGNOSIS", None)) + if age is not None: + features.append(f"age_{_age_bucket(age)}") + + sex = _safe_str(getattr(event, "SEX", None)) + if sex: + features.append(f"sex_{sex.lower()}") + + race = _safe_str(getattr(event, "RACE", None)) + if race: + features.append(f"race_{race.lower()}") + + year = _safe_float(getattr(event, "YEAR_OF_DIAGNOSIS", None)) + if year is not None: + # Decade bucket + decade = int(year // 10) * 10 + features.append(f"diagnosis_decade_{decade}") + + return features + + +def _extract_tumour_features(event: Any) -> List[str]: + """Tumour characteristics tokens.""" + features: List[str] = [] + + primary_site = _safe_str(getattr(event, "PRIMARY_SITE", None)) + if primary_site: + features.append(f"site_{primary_site.lower()}") + + histology = _safe_str(getattr(event, "HISTOLOGY", None)) + if histology: + features.append(f"hist_{histology.lower()}") + + stage = _safe_str(getattr(event, "STAGE", None)) + if stage: + features.append(f"stage_{stage.lower()}") + + grade = _safe_str(getattr(event, "GRADE", None)) + if grade: + features.append(f"grade_{grade.lower()}") + + laterality = _safe_str(getattr(event, "LATERALITY", None)) + if laterality: + features.append(f"lat_{laterality.lower()}") + + tumor_size = _safe_float(getattr(event, "TUMOR_SIZE_MM", None)) + if tumor_size is not None and tumor_size >= 0: + if tumor_size == 0: + size_cat = "in_situ" + elif tumor_size <= 10: + size_cat = "le10mm" + elif tumor_size <= 20: + size_cat = "le20mm" + elif tumor_size <= 50: + size_cat = "le50mm" + else: + size_cat = "gt50mm" + features.append(f"tumor_size_{size_cat}") + + nodes_pos = _safe_float(getattr(event, "REGIONAL_NODES_POSITIVE", None)) + if nodes_pos is not None and nodes_pos >= 0: + node_cat = "none" if nodes_pos == 0 else ("low" if nodes_pos <= 3 else "high") + features.append(f"nodes_pos_{node_cat}") + + return features + + +def _extract_treatment_features(event: Any) -> List[str]: + """Treatment indicator tokens.""" + features: List[str] = [] + for field in ("SURGERY", "RADIATION", "CHEMOTHERAPY"): + val = _safe_str(getattr(event, field, None)) + if val: + features.append(f"{field.lower()}_{val.lower()}") + return features + + +# --------------------------------------------------------------------------- +# Task class +# --------------------------------------------------------------------------- + +class SEERSurvivalTask(BaseTask): + """Overall survival prediction task for the SEER cancer registry. + + Predicts overall survival for cancer patients using SEER clinical data: + + - **survival_months**: continuous time-to-death / censoring time in months + (regression label). + - **vital_status**: binary event indicator — 1 if the patient died (within + the study window), 0 if alive at last follow-up (censored). + + Task Schema: + Input: + - patient_features: sequence of demographic / patient-level tokens + (age bucket, sex, race, diagnosis decade). + - tumour_features: sequence of tumour characteristic tokens + (primary site, histology, AJCC stage, grade, tumour size, + node status). + - treatment_features: sequence of treatment indicator tokens + (surgery, radiation, chemotherapy). + Output: + - survival_months: regression label (float, months ≥ 0). + - vital_status: binary label (int 0/1). + + Args: + min_survival_months: Minimum survival time to include (default 0). + Patients with survival_months < this threshold are excluded. + + Examples: + >>> from pyhealth.datasets import SEERDataset + >>> from pyhealth.tasks import SEERSurvivalTask + >>> dataset = SEERDataset(root="/path/to/seer") + >>> samples = dataset.set_task(SEERSurvivalTask()) + >>> samples[0]['survival_months'] + 84.0 + >>> samples[0]['vital_status'] + 0 + """ + + task_name: str = "SEERSurvivalTask" + input_schema: Dict[str, str] = { + "patient_features": "sequence", + "tumour_features": "sequence", + "treatment_features": "sequence", + } + output_schema: Dict[str, str] = { + "survival_months": "regression", + "vital_status": "binary", + } + + def __init__(self, min_survival_months: float = 0.0): + super().__init__() + self.min_survival_months = min_survival_months + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + events = patient.get_events(event_type="seer") + if not events: + return [] + + samples: List[Dict[str, Any]] = [] + for event in events: + survival_months = _safe_float(getattr(event, "SURVIVAL_MONTHS", None)) + vital_status_raw = _safe_float(getattr(event, "VITAL_STATUS", None)) + + if survival_months is None or vital_status_raw is None: + continue + if survival_months < self.min_survival_months: + continue + + vital_status = int(vital_status_raw) + if vital_status not in (0, 1): + continue + + patient_features = _extract_patient_features(event) + tumour_features = _extract_tumour_features(event) + treatment_features = _extract_treatment_features(event) + + # Need at least patient and tumour features to form a useful sample + if not patient_features or not tumour_features: + continue + + samples.append( + { + "patient_id": patient.patient_id, + "patient_features": patient_features, + "tumour_features": tumour_features, + "treatment_features": treatment_features, + "survival_months": survival_months, + "vital_status": vital_status, + } + ) + + return samples From 01b8261c1ce6dd08f1964183db35eae157f49686 Mon Sep 17 00:00:00 2001 From: lehendo Date: Sun, 5 Apr 2026 16:34:22 -0500 Subject: [PATCH 9/9] stylistic fixes --- pyhealth/datasets/seer.py | 11 ----------- pyhealth/tasks/seer_survival.py | 10 ---------- 2 files changed, 21 deletions(-) diff --git a/pyhealth/datasets/seer.py b/pyhealth/datasets/seer.py index cd2250fd1..15b96f000 100644 --- a/pyhealth/datasets/seer.py +++ b/pyhealth/datasets/seer.py @@ -4,8 +4,6 @@ incidence and survival data from population-based cancer registries covering approximately 48% of the US population. -Data access ------------ SEER data requires a free research data agreement: https://seer.cancer.gov/data/access.html @@ -32,10 +30,6 @@ - ``Survival months`` - ``Vital status recode (study cutoff used)`` (Alive / Dead) - ``Year of diagnosis`` - -Citation: - National Cancer Institute, DCCPS, Surveillance Research Program, - SEER*Stat software (www.seer.cancer.gov/seerstat). """ import logging @@ -102,10 +96,6 @@ def __init__( **kwargs, ) - # ------------------------------------------------------------------ - # Data preparation - # ------------------------------------------------------------------ - @staticmethod def prepare_metadata(root: str) -> None: """Convert a raw SEER*Stat CSV export to the standardised format. @@ -173,7 +163,6 @@ def prepare_metadata(root: str) -> None: logger.info(f"Processing SEER raw file: {raw_file}") df = pd.read_csv(raw_file, low_memory=False) - # --- column name normalisation --- # SEER*Stat uses verbose labels; map common variants to short names. rename: dict = {} for col in df.columns: diff --git a/pyhealth/tasks/seer_survival.py b/pyhealth/tasks/seer_survival.py index 5103805c2..e9624eada 100644 --- a/pyhealth/tasks/seer_survival.py +++ b/pyhealth/tasks/seer_survival.py @@ -23,11 +23,6 @@ logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - def _safe_float(value: Any) -> Optional[float]: if value is None: return None @@ -137,11 +132,6 @@ def _extract_treatment_features(event: Any) -> List[str]: features.append(f"{field.lower()}_{val.lower()}") return features - -# --------------------------------------------------------------------------- -# Task class -# --------------------------------------------------------------------------- - class SEERSurvivalTask(BaseTask): """Overall survival prediction task for the SEER cancer registry.