From 30756f71aaa7104a1d91282415d12a367a2304c1 Mon Sep 17 00:00:00 2001 From: John Carson Date: Sat, 28 Mar 2026 12:36:30 -0400 Subject: [PATCH 1/4] Add clinical jargon benchmark dataset and verification task --- docs/api/datasets.rst | 1 + ...yhealth.datasets.ClinicalJargonDataset.rst | 7 + docs/api/tasks.rst | 1 + ...ealth.tasks.ClinicalJargonVerification.rst | 7 + ...inical_jargon_verification_transformers.py | 76 +++++ pyhealth/datasets/__init__.py | 1 + pyhealth/datasets/clinical_jargon.py | 264 ++++++++++++++++++ .../datasets/configs/clinical_jargon.yaml | 20 ++ pyhealth/tasks/__init__.py | 1 + .../tasks/clinical_jargon_verification.py | 78 ++++++ .../clinical_jargon_examples.csv | 9 + tests/core/test_clinical_jargon.py | 79 ++++++ 12 files changed, 544 insertions(+) create mode 100644 docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst create mode 100644 docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst create mode 100644 examples/clinical_jargon_clinical_jargon_verification_transformers.py create mode 100644 pyhealth/datasets/clinical_jargon.py create mode 100644 pyhealth/datasets/configs/clinical_jargon.yaml create mode 100644 pyhealth/tasks/clinical_jargon_verification.py create mode 100644 test-resources/clinical_jargon/clinical_jargon_examples.csv create mode 100644 tests/core/test_clinical_jargon.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..ec0166994 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -240,6 +240,7 @@ Available Datasets datasets/pyhealth.datasets.ChestXray14Dataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset + datasets/pyhealth.datasets.ClinicalJargonDataset datasets/pyhealth.datasets.ClinVarDataset datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset diff --git a/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst b/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst new file mode 100644 index 000000000..c82341227 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst @@ -0,0 +1,7 @@ +pyhealth.datasets.ClinicalJargonDataset +======================================= + +.. autoclass:: pyhealth.datasets.ClinicalJargonDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index d85d04bc3..9cf0ac729 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -227,6 +227,7 @@ Available Tasks Benchmark EHRShot ChestX-ray14 Binary Classification ChestX-ray14 Multilabel Classification + Clinical Jargon Verification Variant Classification (ClinVar) Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) diff --git a/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst b/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst new file mode 100644 index 000000000..af6c5c186 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.ClinicalJargonVerification +========================================= + +.. autoclass:: pyhealth.tasks.ClinicalJargonVerification + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/clinical_jargon_clinical_jargon_verification_transformers.py b/examples/clinical_jargon_clinical_jargon_verification_transformers.py new file mode 100644 index 000000000..6547d7c9b --- /dev/null +++ b/examples/clinical_jargon_clinical_jargon_verification_transformers.py @@ -0,0 +1,76 @@ +import argparse +import sys +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from pyhealth.datasets import ClinicalJargonDataset, get_dataloader, split_by_sample +from pyhealth.models.transformers_model import TransformersModel +from pyhealth.tasks import ClinicalJargonVerification +from pyhealth.trainer import Trainer + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Clinical jargon verification example with TransformersModel." + ) + parser.add_argument( + "--root", + type=str, + default=str( + PROJECT_ROOT / "test-resources" / "clinical_jargon" + ), + ) + parser.add_argument( + "--model-name", + type=str, + default="emilyalsentzer/Bio_ClinicalBERT", + ) + parser.add_argument( + "--benchmark", + choices=["all", "medlingo", "casi"], + default="medlingo", + ) + parser.add_argument( + "--casi-variant", + choices=["release62", "paper59"], + default="release62", + ) + parser.add_argument("--medlingo-distractors", type=int, default=3) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--batch-size", type=int, default=4) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + dataset = ClinicalJargonDataset(root=args.root) + task = ClinicalJargonVerification( + benchmark=args.benchmark, + casi_variant=args.casi_variant, + medlingo_distractors=args.medlingo_distractors, + ) + samples = dataset.set_task(task) + train_dataset, val_dataset, test_dataset = split_by_sample( + samples, [0.6, 0.2, 0.2], seed=42 + ) + train_loader = get_dataloader(train_dataset, batch_size=args.batch_size, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=args.batch_size, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=args.batch_size, shuffle=False) + + model = TransformersModel(dataset=samples, model_name=args.model_name) + trainer = Trainer(model=model, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + test_dataloader=test_loader, + epochs=args.epochs, + ) + scores = trainer.evaluate(test_loader) + print(scores) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 3edbe06f7..a679a2317 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -65,6 +65,7 @@ def __init__(self, *args, **kwargs): from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset +from .clinical_jargon import ClinicalJargonDataset from .support2 import Support2Dataset from .tcga_prad import TCGAPRADDataset from .splitter import ( diff --git a/pyhealth/datasets/clinical_jargon.py b/pyhealth/datasets/clinical_jargon.py new file mode 100644 index 000000000..21bf793d9 --- /dev/null +++ b/pyhealth/datasets/clinical_jargon.py @@ -0,0 +1,264 @@ +import csv +import json +import re +import urllib.request +from collections import defaultdict +from pathlib import Path +from typing import Optional + +from .base_dataset import BaseDataset + + +MEDLINGO_URL = ( + "https://raw.githubusercontent.com/Flora-jia-jfr/diagnosing_our_datasets/" + "main/datasets/MedLingo/questions.csv" +) +CASI_RELEASE_INDEX_URL = ( + "https://api.github.com/repos/Flora-jia-jfr/diagnosing_our_datasets/contents/" + "datasets/casi/cleaned_dataset_subset?ref=main" +) +MEDLINGO_ONESHOT_PREFIX = ( + "In a clinical note that mentions a high creat, creat stands for creatine. " +) +PAPER59_EXCLUSIONS = frozenset( + { + ("AB", "blood group in ABO system"), + ("US", "United States"), + ("IB", "international baccalaureate"), + ("MS", "master of science"), + ("MP", "military police"), + ("PD", "police department"), + ("MP", "metatarsophalangeal/metacarpophalangeal"), + ("OP", "oblique presentation/occiput posterior"), + ("SA", "slow acting/sustained action"), + ("C&S", "conjunctivae and sclerae"), + ("C&S", "culture and sensitivity"), + ("C&S", "protein C and protein S"), + } +) + + +def split_aliases(answer: str) -> list[str]: + pieces = re.split(r"\s+or\s+", answer.strip()) + aliases = [piece.strip() for piece in pieces if piece.strip()] + return aliases or [answer.strip()] + + +def surface_form_group(abbreviation: str) -> str: + if any(character.isdigit() or not character.isalpha() for character in abbreviation): + return "digit_or_symbol" + if abbreviation.isupper(): + return "all_caps" + if abbreviation.islower(): + return "lowercase" + return "mixed_case" + + +def strip_medlingo_oneshot(question: str) -> str: + if question.startswith(MEDLINGO_ONESHOT_PREFIX): + return question[len(MEDLINGO_ONESHOT_PREFIX) :].strip() + return question.strip() + + +def dedupe(values: list[str]) -> list[str]: + seen: set[str] = set() + ordered: list[str] = [] + for value in values: + if value not in seen: + seen.add(value) + ordered.append(value) + return ordered + + +def token_length(text: str) -> int: + return len(re.findall(r"\w+", text)) + + +def choose_medlingo_distractors( + records: list[dict], + current_record: dict, + distractor_count: int = 3, +) -> list[str]: + gold = current_record["gold_expansion"] + goal_length = token_length(gold) + + def rank(pool: list[str]) -> list[str]: + return sorted( + dedupe([value for value in pool if value != gold]), + key=lambda value: (abs(token_length(value) - goal_length), value.lower()), + ) + + same_group = [ + record["gold_expansion"] + for record in records + if record["sample_id"] != current_record["sample_id"] + and record["surface_form_group"] == current_record["surface_form_group"] + ] + global_pool = [ + record["gold_expansion"] + for record in records + if record["sample_id"] != current_record["sample_id"] + ] + negatives = rank(same_group) + if len(negatives) < distractor_count: + for candidate in rank(global_pool): + if candidate not in negatives: + negatives.append(candidate) + if len(negatives) == distractor_count: + break + return negatives[:distractor_count] + + +class ClinicalJargonDataset(BaseDataset): + """Public clinical jargon benchmark dataset for PyHealth.""" + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + root_path = Path(root) + root_path.mkdir(parents=True, exist_ok=True) + if config_path is None: + config_path = Path(__file__).parent / "configs" / "clinical_jargon.yaml" + normalized_csv = root_path / "clinical_jargon_examples.csv" + if not normalized_csv.exists(): + self.prepare_metadata(root_path) + super().__init__( + root=str(root_path), + tables=["examples"], + dataset_name=dataset_name or "clinical_jargon", + config_path=str(config_path), + **kwargs, + ) + + @staticmethod + def _download_text(url: str, destination: Path) -> str: + if destination.exists(): + return destination.read_text() + payload = urllib.request.urlopen(url).read().decode("utf-8", errors="replace") + destination.write_text(payload) + return payload + + @classmethod + def _fetch_medlingo_rows(cls, cache_dir: Path) -> list[dict]: + csv_text = cls._download_text(MEDLINGO_URL, cache_dir / "medlingo_questions.csv") + return list(csv.DictReader(csv_text.splitlines())) + + @classmethod + def _fetch_casi_rows(cls, cache_dir: Path) -> list[dict]: + index_path = cache_dir / "casi_release_index.json" + if index_path.exists(): + entries = json.loads(index_path.read_text()) + else: + entries = json.loads( + urllib.request.urlopen(CASI_RELEASE_INDEX_URL) + .read() + .decode("utf-8", errors="replace") + ) + index_path.write_text(json.dumps(entries, indent=2)) + rows: list[dict] = [] + for entry in entries: + file_name = entry["name"] + file_text = cls._download_text(entry["download_url"], cache_dir / file_name) + for row in csv.DictReader(file_text.splitlines()): + row["source_file"] = file_name + rows.append(row) + return rows + + @classmethod + def prepare_metadata(cls, root: Path) -> None: + cache_dir = root / "cache" + cache_dir.mkdir(parents=True, exist_ok=True) + + medlingo_rows: list[dict] = [] + for index, row in enumerate(cls._fetch_medlingo_rows(cache_dir), start=1): + aliases = split_aliases(row["answer"]) + medlingo_rows.append( + { + "sample_id": f"medlingo_{index:03d}", + "benchmark": "medlingo", + "abbreviation": row["word1"], + "context": "", + "question": row["question"].strip(), + "question_zero_shot": strip_medlingo_oneshot(row["question"]), + "gold_expansion": aliases[0], + "gold_aliases_json": json.dumps(aliases), + "surface_form_group": surface_form_group(row["word1"]), + "paper59_included": "true", + "source_file": "medlingo_questions.csv", + } + ) + + medlingo_candidate_map = {} + for row in medlingo_rows: + negatives = choose_medlingo_distractors(medlingo_rows, row, distractor_count=3) + medlingo_candidate_map[row["sample_id"]] = [row["gold_expansion"], *negatives] + + for row in medlingo_rows: + row["candidate_expansions_json"] = json.dumps( + medlingo_candidate_map[row["sample_id"]] + ) + row["candidate_expansions_paper59_json"] = row["candidate_expansions_json"] + + casi_rows = cls._fetch_casi_rows(cache_dir) + release_expansions: dict[str, list[str]] = defaultdict(list) + paper59_expansions: dict[str, list[str]] = defaultdict(list) + for row in casi_rows: + abbreviation = row["sf"] + expansion = row["target_lf"] + release_expansions[abbreviation].append(expansion) + if (abbreviation, expansion) not in PAPER59_EXCLUSIONS: + paper59_expansions[abbreviation].append(expansion) + + release_expansions = { + key: dedupe(sorted(values, key=str.lower)) + for key, values in release_expansions.items() + } + paper59_expansions = { + key: dedupe(sorted(values, key=str.lower)) + for key, values in paper59_expansions.items() + } + + normalized_rows = list(medlingo_rows) + for index, row in enumerate(casi_rows, start=1): + abbreviation = row["sf"] + expansion = row["target_lf"] + question = f"{row['context'].strip()} In this sentence, {abbreviation} means:" + normalized_rows.append( + { + "sample_id": f"casi_{index:04d}", + "benchmark": "casi", + "abbreviation": abbreviation, + "context": row["context"].strip(), + "question": question, + "question_zero_shot": question, + "gold_expansion": expansion, + "gold_aliases_json": json.dumps([expansion]), + "surface_form_group": surface_form_group(abbreviation), + "paper59_included": str( + (abbreviation, expansion) not in PAPER59_EXCLUSIONS + ).lower(), + "source_file": row["source_file"], + "candidate_expansions_json": json.dumps( + release_expansions[abbreviation] + ), + "candidate_expansions_paper59_json": json.dumps( + paper59_expansions.get(abbreviation, []) + ), + } + ) + + output_path = root / "clinical_jargon_examples.csv" + with output_path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(normalized_rows[0].keys())) + writer.writeheader() + writer.writerows(normalized_rows) + + @property + def default_task(self): + from ..tasks import ClinicalJargonVerification + + return ClinicalJargonVerification() diff --git a/pyhealth/datasets/configs/clinical_jargon.yaml b/pyhealth/datasets/configs/clinical_jargon.yaml new file mode 100644 index 000000000..9592c5114 --- /dev/null +++ b/pyhealth/datasets/configs/clinical_jargon.yaml @@ -0,0 +1,20 @@ +version: "1.0" +tables: + examples: + file_path: "clinical_jargon_examples.csv" + patient_id: null + timestamp: null + attributes: + - "sample_id" + - "benchmark" + - "abbreviation" + - "context" + - "question" + - "question_zero_shot" + - "gold_expansion" + - "gold_aliases_json" + - "surface_form_group" + - "paper59_included" + - "source_file" + - "candidate_expansions_json" + - "candidate_expansions_paper59_json" diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..e90507df7 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -11,6 +11,7 @@ ) from .chestxray14_binary_classification import ChestXray14BinaryClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification +from .clinical_jargon_verification import ClinicalJargonVerification from .covid19_cxr_classification import COVID19CXRClassification from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 from .drug_recommendation import ( diff --git a/pyhealth/tasks/clinical_jargon_verification.py b/pyhealth/tasks/clinical_jargon_verification.py new file mode 100644 index 000000000..19fdc487b --- /dev/null +++ b/pyhealth/tasks/clinical_jargon_verification.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, List +import json + +from ..data import Patient +from .base_task import BaseTask + + +def parse_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + return str(value).strip().lower() == "true" + + +class ClinicalJargonVerification(BaseTask): + """Binary candidate-verification task for public clinical jargon benchmarks.""" + + task_name: str = "ClinicalJargonVerification" + input_schema: Dict[str, str] = {"paired_text": "text"} + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__( + self, + benchmark: str = "all", + casi_variant: str = "release62", + medlingo_distractors: int = 3, + ) -> None: + if benchmark not in {"all", "medlingo", "casi"}: + raise ValueError(f"Unsupported benchmark: {benchmark}") + if casi_variant not in {"release62", "paper59"}: + raise ValueError(f"Unsupported CASI variant: {casi_variant}") + if medlingo_distractors not in {1, 2, 3}: + raise ValueError("medlingo_distractors must be 1, 2, or 3") + self.benchmark = benchmark + self.casi_variant = casi_variant + self.medlingo_distractors = medlingo_distractors + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + events = patient.get_events(event_type="examples") + if len(events) != 1: + return [] + event = events[0] + + if self.benchmark != "all" and event.benchmark != self.benchmark: + return [] + if event.benchmark == "casi" and self.casi_variant == "paper59": + if not parse_bool(event.paper59_included): + return [] + + if event.benchmark == "casi" and self.casi_variant == "paper59": + candidates = json.loads(event.candidate_expansions_paper59_json) + else: + candidates = json.loads(event.candidate_expansions_json) + + if event.benchmark == "medlingo": + gold = event.gold_expansion + negatives = [candidate for candidate in candidates if candidate != gold] + candidates = [gold, *negatives[: self.medlingo_distractors]] + + samples: List[Dict[str, Any]] = [] + for candidate in candidates: + samples.append( + { + "patient_id": patient.patient_id, + "record_id": event.sample_id, + "sample_id": event.sample_id, + "benchmark": event.benchmark, + "abbreviation": event.abbreviation, + "candidate_expansion": candidate, + "surface_form_group": event.surface_form_group, + "paired_text": ( + f"Question: {event.question}\n" + f"Candidate expansion: {candidate}\n" + "Is this the correct expansion?" + ), + "label": int(candidate == event.gold_expansion), + } + ) + return samples diff --git a/test-resources/clinical_jargon/clinical_jargon_examples.csv b/test-resources/clinical_jargon/clinical_jargon_examples.csv new file mode 100644 index 000000000..d2ea2dd18 --- /dev/null +++ b/test-resources/clinical_jargon/clinical_jargon_examples.csv @@ -0,0 +1,9 @@ +sample_id,benchmark,abbreviation,context,question,question_zero_shot,gold_expansion,gold_aliases_json,surface_form_group,paper59_included,source_file,candidate_expansions_json,candidate_expansions_paper59_json +medlingo_001,medlingo,PRN,,"In a clinical note that mentions a high creat, creat stands for creatine. In a clinical note that mentions PRN, PRN stands for","In a clinical note that mentions PRN, PRN stands for",as needed,"[""as needed""]",all_caps,true,medlingo_questions.csv,"[""as needed"", ""vital signs"", ""hypertension"", ""abdomen""]","[""as needed"", ""vital signs"", ""hypertension"", ""abdomen""]" +medlingo_002,medlingo,VS,,"In a clinical note that mentions a high creat, creat stands for creatine. In a clinical note that mentions VS every four hours, VS stands for","In a clinical note that mentions VS every four hours, VS stands for",vital signs,"[""vital signs""]",all_caps,true,medlingo_questions.csv,"[""vital signs"", ""as needed"", ""hypertension"", ""abdomen""]","[""vital signs"", ""as needed"", ""hypertension"", ""abdomen""]" +casi_0001,casi,AB,"Elective AB in 1989.","Elective AB in 1989. In this sentence, AB means:","Elective AB in 1989. In this sentence, AB means:",abortion,"[""abortion""]",all_caps,true,cleaned_dataset_AB.csv,"[""abortion"", ""blood group in ABO system""]","[""abortion""]" +casi_0002,casi,AB,"Type and screen: AB positive.","Type and screen: AB positive. In this sentence, AB means:","Type and screen: AB positive. In this sentence, AB means:","blood group in ABO system","[""blood group in ABO system""]",all_caps,false,cleaned_dataset_AB.csv,"[""abortion"", ""blood group in ABO system""]","[""abortion""]" +casi_0003,casi,US,"US abdomen ordered.","US abdomen ordered. In this sentence, US means:","US abdomen ordered. In this sentence, US means:",ultrasound,"[""ultrasound""]",all_caps,true,cleaned_dataset_US.csv,"[""United States"", ""ultrasound""]","[""ultrasound""]" +casi_0004,casi,US,"Travel history from the US noted.","Travel history from the US noted. In this sentence, US means:","Travel history from the US noted. In this sentence, US means:","United States","[""United States""]",all_caps,false,cleaned_dataset_US.csv,"[""United States"", ""ultrasound""]","[""ultrasound""]" +casi_0005,casi,CA,"History notable for CA of the colon.","History notable for CA of the colon. In this sentence, CA means:","History notable for CA of the colon. In this sentence, CA means:",cancer,"[""cancer""]",all_caps,true,cleaned_dataset_CA.csv,"[""cancer"", ""carbohydrate antigen""]","[""cancer"", ""carbohydrate antigen""]" +casi_0006,casi,CA,"Lab review included an elevated CA 19-9.","Lab review included an elevated CA 19-9. In this sentence, CA means:","Lab review included an elevated CA 19-9. In this sentence, CA means:","carbohydrate antigen","[""carbohydrate antigen""]",all_caps,true,cleaned_dataset_CA.csv,"[""cancer"", ""carbohydrate antigen""]","[""cancer"", ""carbohydrate antigen""]" diff --git a/tests/core/test_clinical_jargon.py b/tests/core/test_clinical_jargon.py new file mode 100644 index 000000000..c7cea7f3e --- /dev/null +++ b/tests/core/test_clinical_jargon.py @@ -0,0 +1,79 @@ +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + +from pyhealth.datasets import ClinicalJargonDataset +from pyhealth.tasks import ClinicalJargonVerification + + +class TestClinicalJargonDataset(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.root = ( + Path(__file__).parent.parent.parent + / "test-resources" + / "clinical_jargon" + ) + cls.cache_dir = TemporaryDirectory() + + @classmethod + def tearDownClass(cls): + cls.cache_dir.cleanup() + + def make_dataset(self): + return ClinicalJargonDataset( + root=str(self.root), + cache_dir=self.cache_dir.name, + ) + + def test_dataset_initialization(self): + dataset = self.make_dataset() + self.assertEqual(dataset.dataset_name, "clinical_jargon") + + def test_default_task(self): + dataset = self.make_dataset() + self.assertIsInstance(dataset.default_task, ClinicalJargonVerification) + + def test_num_patients(self): + dataset = self.make_dataset() + self.assertEqual(len(dataset.unique_patient_ids), 8) + + def test_get_patient(self): + dataset = self.make_dataset() + patient = dataset.get_patient(dataset.unique_patient_ids[0]) + self.assertIsNotNone(patient) + self.assertEqual(len(patient.get_events("examples")), 1) + + def test_release62_task_samples(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification(benchmark="all", casi_variant="release62") + samples = dataset.set_task(task) + self.assertEqual(len(samples), 20) + + def test_paper59_task_samples(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification(benchmark="all", casi_variant="paper59") + samples = dataset.set_task(task) + self.assertEqual(len(samples), 14) + + def test_medlingo_distractor_control(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification( + benchmark="medlingo", + medlingo_distractors=1, + ) + samples = dataset.set_task(task) + self.assertEqual(len(samples), 4) + + def test_task_output_shape(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification(benchmark="casi", casi_variant="paper59") + samples = dataset.set_task(task) + sample = samples[0] + self.assertIn("paired_text", sample) + self.assertIn("label", sample) + self.assertIn("candidate_expansion", sample) + + +if __name__ == "__main__": + unittest.main() From fe983a3da81779caf9f52a5eee6b7145bba9795c Mon Sep 17 00:00:00 2001 From: John Carson Date: Tue, 7 Apr 2026 20:02:39 -0400 Subject: [PATCH 2/4] Polish clinical jargon docs and example guidance --- ...yhealth.datasets.ClinicalJargonDataset.rst | 14 ++ ...ealth.tasks.ClinicalJargonVerification.rst | 18 +++ ...inical_jargon_verification_transformers.py | 56 +++++++- pyhealth/datasets/clinical_jargon.py | 133 +++++++++++++++++- .../tasks/clinical_jargon_verification.py | 58 +++++++- 5 files changed, 272 insertions(+), 7 deletions(-) diff --git a/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst b/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst index c82341227..2f8e4088c 100644 --- a/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst +++ b/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst @@ -1,6 +1,20 @@ pyhealth.datasets.ClinicalJargonDataset ======================================= +Public clinical jargon benchmark dataset backed by the released MedLingo and +CASI assets from Jia et al. (CHIL 2025). + +Example +------- + +.. code-block:: python + + from pyhealth.datasets import ClinicalJargonDataset + + dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon") + task = dataset.default_task + samples = dataset.set_task(task) + .. autoclass:: pyhealth.datasets.ClinicalJargonDataset :members: :undoc-members: diff --git a/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst b/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst index af6c5c186..6d0f68e45 100644 --- a/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst +++ b/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst @@ -1,6 +1,24 @@ pyhealth.tasks.ClinicalJargonVerification ========================================= +Binary candidate-verification task for the public MedLingo and CASI clinical +jargon benchmarks. + +Example +------- + +.. code-block:: python + + from pyhealth.datasets import ClinicalJargonDataset + from pyhealth.tasks import ClinicalJargonVerification + + dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon") + task = ClinicalJargonVerification( + benchmark="medlingo", + medlingo_distractors=1, + ) + samples = dataset.set_task(task) + .. autoclass:: pyhealth.tasks.ClinicalJargonVerification :members: :undoc-members: diff --git a/examples/clinical_jargon_clinical_jargon_verification_transformers.py b/examples/clinical_jargon_clinical_jargon_verification_transformers.py index 6547d7c9b..7b10f1767 100644 --- a/examples/clinical_jargon_clinical_jargon_verification_transformers.py +++ b/examples/clinical_jargon_clinical_jargon_verification_transformers.py @@ -1,3 +1,20 @@ +"""Clinical jargon verification example with `TransformersModel`. + +This example is the course-facing ablation entrypoint for the public clinical +jargon benchmark contribution. It supports three lightweight task ablations: + +- switch between ``medlingo``, ``casi``, and ``all`` benchmark subsets +- switch CASI between ``release62`` and ``paper59`` variants +- vary MedLingo distractor count through ``--medlingo-distractors`` + +Example commands: + python3 examples/clinical_jargon_clinical_jargon_verification_transformers.py \ + --benchmark medlingo --medlingo-distractors 1 --epochs 1 + + python3 examples/clinical_jargon_clinical_jargon_verification_transformers.py \ + --benchmark casi --casi-variant paper59 --epochs 1 +""" + import argparse import sys from pathlib import Path @@ -13,8 +30,20 @@ def parse_args() -> argparse.Namespace: + """Parse CLI arguments for the clinical jargon example. + + Returns: + Parsed command-line arguments for dataset root, model choice, task + configuration, and training hyperparameters. + """ parser = argparse.ArgumentParser( - description="Clinical jargon verification example with TransformersModel." + description="Clinical jargon verification example with TransformersModel.", + epilog=( + "Ablation knobs: --benchmark changes the benchmark subset, " + "--casi-variant changes the CASI filtering mode, and " + "--medlingo-distractors changes the number of negative MedLingo " + "candidates." + ), ) parser.add_argument( "--root", @@ -22,29 +51,39 @@ def parse_args() -> argparse.Namespace: default=str( PROJECT_ROOT / "test-resources" / "clinical_jargon" ), + help="Dataset root containing clinical_jargon_examples.csv or raw assets.", ) parser.add_argument( "--model-name", type=str, default="emilyalsentzer/Bio_ClinicalBERT", + help="Hugging Face model name used by TransformersModel.", ) parser.add_argument( "--benchmark", choices=["all", "medlingo", "casi"], default="medlingo", + help="Benchmark subset used for the run.", ) parser.add_argument( "--casi-variant", choices=["release62", "paper59"], default="release62", + help="CASI filtering mode when --benchmark includes CASI.", ) - parser.add_argument("--medlingo-distractors", type=int, default=3) - parser.add_argument("--epochs", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument( + "--medlingo-distractors", + type=int, + default=3, + help="Number of negative MedLingo candidates retained per sample.", + ) + parser.add_argument("--epochs", type=int, default=1, help="Training epochs.") + parser.add_argument("--batch-size", type=int, default=4, help="Batch size.") return parser.parse_args() def main() -> None: + """Run one clinical jargon verification configuration.""" args = parse_args() dataset = ClinicalJargonDataset(root=args.root) task = ClinicalJargonVerification( @@ -53,6 +92,15 @@ def main() -> None: medlingo_distractors=args.medlingo_distractors, ) samples = dataset.set_task(task) + print( + { + "benchmark": args.benchmark, + "casi_variant": args.casi_variant, + "medlingo_distractors": args.medlingo_distractors, + "num_samples": len(samples), + "model_name": args.model_name, + } + ) train_dataset, val_dataset, test_dataset = split_by_sample( samples, [0.6, 0.2, 0.2], seed=42 ) diff --git a/pyhealth/datasets/clinical_jargon.py b/pyhealth/datasets/clinical_jargon.py index 21bf793d9..590ceba78 100644 --- a/pyhealth/datasets/clinical_jargon.py +++ b/pyhealth/datasets/clinical_jargon.py @@ -1,3 +1,11 @@ +"""Clinical jargon benchmark dataset for PyHealth. + +This module exposes a public clinical jargon benchmark derived from the +released MedLingo and CASI assets from Jia et al. (CHIL 2025). The dataset is +normalized into a single CSV file that PyHealth can load as an `examples` +table. +""" + import csv import json import re @@ -39,12 +47,30 @@ def split_aliases(answer: str) -> list[str]: + """Split released answer strings into canonical aliases. + + Args: + answer: The released answer string. Some MedLingo answers contain + multiple acceptable expansions joined by ``or``. + + Returns: + A non-empty list of acceptable expansion strings. + """ pieces = re.split(r"\s+or\s+", answer.strip()) aliases = [piece.strip() for piece in pieces if piece.strip()] return aliases or [answer.strip()] def surface_form_group(abbreviation: str) -> str: + """Assign a surface-form bucket to a jargon token. + + Args: + abbreviation: The shorthand token being evaluated. + + Returns: + One of ``all_caps``, ``lowercase``, ``mixed_case``, or + ``digit_or_symbol``. + """ if any(character.isdigit() or not character.isalpha() for character in abbreviation): return "digit_or_symbol" if abbreviation.isupper(): @@ -55,12 +81,28 @@ def surface_form_group(abbreviation: str) -> str: def strip_medlingo_oneshot(question: str) -> str: + """Remove the released MedLingo one-shot demonstration when present. + + Args: + question: A released MedLingo question string. + + Returns: + The same question without the built-in one-shot example prefix. + """ if question.startswith(MEDLINGO_ONESHOT_PREFIX): return question[len(MEDLINGO_ONESHOT_PREFIX) :].strip() return question.strip() def dedupe(values: list[str]) -> list[str]: + """Preserve order while removing duplicate strings. + + Args: + values: Ordered candidate strings. + + Returns: + The input values with duplicates removed in first-seen order. + """ seen: set[str] = set() ordered: list[str] = [] for value in values: @@ -71,6 +113,14 @@ def dedupe(values: list[str]) -> list[str]: def token_length(text: str) -> int: + """Count alphanumeric tokens in a string. + + Args: + text: The input text. + + Returns: + The number of regex word tokens in ``text``. + """ return len(re.findall(r"\w+", text)) @@ -79,6 +129,20 @@ def choose_medlingo_distractors( current_record: dict, distractor_count: int = 3, ) -> list[str]: + """Select distractor expansions for a MedLingo item. + + The ranking favors candidate expansions with similar token length, first + within the same surface-form group and then globally if more negatives are + needed. + + Args: + records: All normalized MedLingo records. + current_record: The record whose distractors are being chosen. + distractor_count: Number of negative candidates to return. + + Returns: + A list of distractor expansions ordered from closest to farthest match. + """ gold = current_record["gold_expansion"] goal_length = token_length(gold) @@ -110,7 +174,30 @@ def rank(pool: list[str]) -> list[str]: class ClinicalJargonDataset(BaseDataset): - """Public clinical jargon benchmark dataset for PyHealth.""" + """Public clinical jargon benchmark dataset for PyHealth. + + The dataset downloads the public MedLingo and CASI benchmark assets, + normalizes them into a single ``clinical_jargon_examples.csv`` file, and + exposes the result through the PyHealth dataset API. + + The default task is :class:`pyhealth.tasks.ClinicalJargonVerification`, + which converts each benchmark item into paired-text binary verification + samples over candidate expansions. + + Args: + root: Root directory used to store normalized benchmark files. + dataset_name: Optional dataset name. Defaults to ``clinical_jargon``. + config_path: Optional path to the dataset config file. + **kwargs: Additional keyword arguments forwarded to + :class:`pyhealth.datasets.BaseDataset`. + + Examples: + >>> from pyhealth.datasets import ClinicalJargonDataset + >>> dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon") + >>> task = dataset.default_task + >>> samples = dataset.set_task(task) + >>> print(samples[0]["paired_text"]) + """ def __init__( self, @@ -119,6 +206,14 @@ def __init__( config_path: Optional[str] = None, **kwargs, ) -> None: + """Initialize the public clinical jargon dataset. + + Args: + root: Root directory used to cache normalized files. + dataset_name: Optional dataset name override. + config_path: Optional dataset config path override. + **kwargs: Additional keyword arguments passed to ``BaseDataset``. + """ root_path = Path(root) root_path.mkdir(parents=True, exist_ok=True) if config_path is None: @@ -136,6 +231,15 @@ def __init__( @staticmethod def _download_text(url: str, destination: Path) -> str: + """Download text content unless it is already cached locally. + + Args: + url: Source URL. + destination: Cache path for the downloaded content. + + Returns: + The downloaded or cached text payload. + """ if destination.exists(): return destination.read_text() payload = urllib.request.urlopen(url).read().decode("utf-8", errors="replace") @@ -144,11 +248,27 @@ def _download_text(url: str, destination: Path) -> str: @classmethod def _fetch_medlingo_rows(cls, cache_dir: Path) -> list[dict]: + """Load raw MedLingo rows from the released public CSV. + + Args: + cache_dir: Cache directory for downloaded assets. + + Returns: + Raw MedLingo rows as dictionaries. + """ csv_text = cls._download_text(MEDLINGO_URL, cache_dir / "medlingo_questions.csv") return list(csv.DictReader(csv_text.splitlines())) @classmethod def _fetch_casi_rows(cls, cache_dir: Path) -> list[dict]: + """Load raw CASI rows from the released public subset. + + Args: + cache_dir: Cache directory for downloaded assets. + + Returns: + Raw CASI rows as dictionaries with source-file metadata. + """ index_path = cache_dir / "casi_release_index.json" if index_path.exists(): entries = json.loads(index_path.read_text()) @@ -170,6 +290,12 @@ def _fetch_casi_rows(cls, cache_dir: Path) -> list[dict]: @classmethod def prepare_metadata(cls, root: Path) -> None: + """Normalize public MedLingo and CASI assets into one CSV file. + + Args: + root: Root directory where the normalized file and cache should be + written. + """ cache_dir = root / "cache" cache_dir.mkdir(parents=True, exist_ok=True) @@ -259,6 +385,11 @@ def prepare_metadata(cls, root: Path) -> None: @property def default_task(self): + """Return the default task for the dataset. + + Returns: + ClinicalJargonVerification: The default binary verification task. + """ from ..tasks import ClinicalJargonVerification return ClinicalJargonVerification() diff --git a/pyhealth/tasks/clinical_jargon_verification.py b/pyhealth/tasks/clinical_jargon_verification.py index 19fdc487b..76ecc8657 100644 --- a/pyhealth/tasks/clinical_jargon_verification.py +++ b/pyhealth/tasks/clinical_jargon_verification.py @@ -1,18 +1,51 @@ -from typing import Any, Dict, List import json +from typing import Any, Dict, List from ..data import Patient from .base_task import BaseTask def parse_bool(value: Any) -> bool: + """Normalize released boolean-like values. + + Args: + value: A boolean or string-like value from the normalized benchmark. + + Returns: + ``True`` when the input represents a truthy flag, otherwise ``False``. + """ if isinstance(value, bool): return value return str(value).strip().lower() == "true" class ClinicalJargonVerification(BaseTask): - """Binary candidate-verification task for public clinical jargon benchmarks.""" + """Binary candidate-verification task for public clinical jargon benchmarks. + + Each sample pairs a benchmark question with one candidate expansion and + asks whether that candidate is correct. This reframes the released jargon + benchmark into a standard PyHealth binary classification task. + + Args: + benchmark: Benchmark subset to include. One of ``all``, ``medlingo``, + or ``casi``. + casi_variant: Which CASI view to expose. ``release62`` uses the public + release as-is, while ``paper59`` applies the local exclusion set + used to approximate the paper's filtered benchmark. + medlingo_distractors: Number of negative MedLingo candidates to retain + per example. Valid values are ``1``, ``2``, and ``3``. + + Examples: + >>> from pyhealth.datasets import ClinicalJargonDataset + >>> from pyhealth.tasks import ClinicalJargonVerification + >>> dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon") + >>> task = ClinicalJargonVerification( + ... benchmark="medlingo", + ... medlingo_distractors=1, + ... ) + >>> samples = dataset.set_task(task) + >>> print(samples[0]["label"]) + """ task_name: str = "ClinicalJargonVerification" input_schema: Dict[str, str] = {"paired_text": "text"} @@ -24,6 +57,16 @@ def __init__( casi_variant: str = "release62", medlingo_distractors: int = 3, ) -> None: + """Initialize the clinical jargon verification task. + + Args: + benchmark: Benchmark subset to include. + casi_variant: CASI variant to expose. + medlingo_distractors: Number of negative MedLingo candidates. + + Raises: + ValueError: If any task configuration argument is unsupported. + """ if benchmark not in {"all", "medlingo", "casi"}: raise ValueError(f"Unsupported benchmark: {benchmark}") if casi_variant not in {"release62", "paper59"}: @@ -35,6 +78,17 @@ def __init__( self.medlingo_distractors = medlingo_distractors def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Generate verification samples for one patient record. + + Args: + patient: A PyHealth patient object containing one ``examples`` + event from the normalized clinical jargon dataset. + + Returns: + A list of binary verification samples. MedLingo records emit one + positive and ``medlingo_distractors`` negative candidates. CASI + emits all released candidates for the selected variant. + """ events = patient.get_events(event_type="examples") if len(events) != 1: return [] From 59ffe38668672cb9ad37c4a21859dd26b79c0a89 Mon Sep 17 00:00:00 2001 From: John Carson Date: Thu, 9 Apr 2026 18:49:35 -0400 Subject: [PATCH 3/4] Address Copilot review feedback for clinical jargon --- ...yhealth.datasets.ClinicalJargonDataset.rst | 2 +- ...ealth.tasks.ClinicalJargonVerification.rst | 2 +- ...inical_jargon_verification_transformers.py | 15 +++-- pyhealth/datasets/clinical_jargon.py | 56 ++++++++++++++----- tests/core/test_clinical_jargon.py | 15 +++++ 5 files changed, 68 insertions(+), 22 deletions(-) diff --git a/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst b/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst index 2f8e4088c..2c98a0637 100644 --- a/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst +++ b/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst @@ -11,7 +11,7 @@ Example from pyhealth.datasets import ClinicalJargonDataset - dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon") + dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon", download=True) task = dataset.default_task samples = dataset.set_task(task) diff --git a/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst b/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst index 6d0f68e45..93d9c8a95 100644 --- a/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst +++ b/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst @@ -12,7 +12,7 @@ Example from pyhealth.datasets import ClinicalJargonDataset from pyhealth.tasks import ClinicalJargonVerification - dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon") + dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon", download=True) task = ClinicalJargonVerification( benchmark="medlingo", medlingo_distractors=1, diff --git a/examples/clinical_jargon_clinical_jargon_verification_transformers.py b/examples/clinical_jargon_clinical_jargon_verification_transformers.py index 7b10f1767..ea69d08f5 100644 --- a/examples/clinical_jargon_clinical_jargon_verification_transformers.py +++ b/examples/clinical_jargon_clinical_jargon_verification_transformers.py @@ -7,6 +7,9 @@ - switch CASI between ``release62`` and ``paper59`` variants - vary MedLingo distractor count through ``--medlingo-distractors`` +Run this example from an environment where PyHealth is installed, such as +``pip install -e .`` from the repository root. + Example commands: python3 examples/clinical_jargon_clinical_jargon_verification_transformers.py \ --benchmark medlingo --medlingo-distractors 1 --epochs 1 @@ -16,12 +19,9 @@ """ import argparse -import sys from pathlib import Path PROJECT_ROOT = Path(__file__).resolve().parents[1] -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) from pyhealth.datasets import ClinicalJargonDataset, get_dataloader, split_by_sample from pyhealth.models.transformers_model import TransformersModel @@ -51,7 +51,12 @@ def parse_args() -> argparse.Namespace: default=str( PROJECT_ROOT / "test-resources" / "clinical_jargon" ), - help="Dataset root containing clinical_jargon_examples.csv or raw assets.", + help="Dataset root containing clinical_jargon_examples.csv.", + ) + parser.add_argument( + "--download", + action="store_true", + help="Fetch the public source assets when the normalized CSV is missing.", ) parser.add_argument( "--model-name", @@ -85,7 +90,7 @@ def parse_args() -> argparse.Namespace: def main() -> None: """Run one clinical jargon verification configuration.""" args = parse_args() - dataset = ClinicalJargonDataset(root=args.root) + dataset = ClinicalJargonDataset(root=args.root, download=args.download) task = ClinicalJargonVerification( benchmark=args.benchmark, casi_variant=args.casi_variant, diff --git a/pyhealth/datasets/clinical_jargon.py b/pyhealth/datasets/clinical_jargon.py index 590ceba78..e71c59568 100644 --- a/pyhealth/datasets/clinical_jargon.py +++ b/pyhealth/datasets/clinical_jargon.py @@ -28,6 +28,7 @@ MEDLINGO_ONESHOT_PREFIX = ( "In a clinical note that mentions a high creat, creat stands for creatine. " ) +DOWNLOAD_TIMEOUT_SECONDS = 10 PAPER59_EXCLUSIONS = frozenset( { ("AB", "blood group in ABO system"), @@ -188,12 +189,18 @@ class ClinicalJargonDataset(BaseDataset): root: Root directory used to store normalized benchmark files. dataset_name: Optional dataset name. Defaults to ``clinical_jargon``. config_path: Optional path to the dataset config file. + download: Whether to download and normalize the public source assets + when ``clinical_jargon_examples.csv`` is missing. Defaults to + ``False``. **kwargs: Additional keyword arguments forwarded to :class:`pyhealth.datasets.BaseDataset`. Examples: >>> from pyhealth.datasets import ClinicalJargonDataset - >>> dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon") + >>> dataset = ClinicalJargonDataset( + ... root="/tmp/clinical_jargon", + ... download=True, + ... ) >>> task = dataset.default_task >>> samples = dataset.set_task(task) >>> print(samples[0]["paired_text"]) @@ -204,6 +211,7 @@ def __init__( root: str, dataset_name: Optional[str] = None, config_path: Optional[str] = None, + download: bool = False, **kwargs, ) -> None: """Initialize the public clinical jargon dataset. @@ -212,6 +220,8 @@ def __init__( root: Root directory used to cache normalized files. dataset_name: Optional dataset name override. config_path: Optional dataset config path override. + download: Whether to fetch and normalize the released benchmark + assets when the normalized CSV is missing. **kwargs: Additional keyword arguments passed to ``BaseDataset``. """ root_path = Path(root) @@ -220,6 +230,12 @@ def __init__( config_path = Path(__file__).parent / "configs" / "clinical_jargon.yaml" normalized_csv = root_path / "clinical_jargon_examples.csv" if not normalized_csv.exists(): + if not download: + raise FileNotFoundError( + f"Missing normalized metadata at {normalized_csv}. " + "Pass download=True to fetch the public MedLingo and CASI " + "assets and generate this CSV." + ) self.prepare_metadata(root_path) super().__init__( root=str(root_path), @@ -241,11 +257,29 @@ def _download_text(url: str, destination: Path) -> str: The downloaded or cached text payload. """ if destination.exists(): - return destination.read_text() - payload = urllib.request.urlopen(url).read().decode("utf-8", errors="replace") - destination.write_text(payload) + return destination.read_text(encoding="utf-8", errors="replace") + request = urllib.request.Request(url) + with urllib.request.urlopen( + request, + timeout=DOWNLOAD_TIMEOUT_SECONDS, + ) as response: + payload = response.read().decode("utf-8", errors="replace") + destination.write_text(payload, encoding="utf-8", errors="replace") return payload + @staticmethod + def _validated_file_name(file_name: str) -> str: + """Validate a remotely provided cache filename.""" + candidate = Path(file_name) + if ( + not file_name + or candidate.is_absolute() + or candidate.name != file_name + or candidate.parent != Path(".") + ): + raise ValueError(f"Invalid cache file name: {file_name}") + return file_name + @classmethod def _fetch_medlingo_rows(cls, cache_dir: Path) -> list[dict]: """Load raw MedLingo rows from the released public CSV. @@ -270,18 +304,10 @@ def _fetch_casi_rows(cls, cache_dir: Path) -> list[dict]: Raw CASI rows as dictionaries with source-file metadata. """ index_path = cache_dir / "casi_release_index.json" - if index_path.exists(): - entries = json.loads(index_path.read_text()) - else: - entries = json.loads( - urllib.request.urlopen(CASI_RELEASE_INDEX_URL) - .read() - .decode("utf-8", errors="replace") - ) - index_path.write_text(json.dumps(entries, indent=2)) + entries = json.loads(cls._download_text(CASI_RELEASE_INDEX_URL, index_path)) rows: list[dict] = [] for entry in entries: - file_name = entry["name"] + file_name = cls._validated_file_name(entry["name"]) file_text = cls._download_text(entry["download_url"], cache_dir / file_name) for row in csv.DictReader(file_text.splitlines()): row["source_file"] = file_name @@ -378,7 +404,7 @@ def prepare_metadata(cls, root: Path) -> None: ) output_path = root / "clinical_jargon_examples.csv" - with output_path.open("w", newline="") as handle: + with output_path.open("w", newline="", encoding="utf-8") as handle: writer = csv.DictWriter(handle, fieldnames=list(normalized_rows[0].keys())) writer.writeheader() writer.writerows(normalized_rows) diff --git a/tests/core/test_clinical_jargon.py b/tests/core/test_clinical_jargon.py index c7cea7f3e..6adc55e9a 100644 --- a/tests/core/test_clinical_jargon.py +++ b/tests/core/test_clinical_jargon.py @@ -1,6 +1,7 @@ import unittest from pathlib import Path from tempfile import TemporaryDirectory +from unittest.mock import patch from pyhealth.datasets import ClinicalJargonDataset from pyhealth.tasks import ClinicalJargonVerification @@ -30,6 +31,11 @@ def test_dataset_initialization(self): dataset = self.make_dataset() self.assertEqual(dataset.dataset_name, "clinical_jargon") + def test_missing_csv_requires_explicit_download(self): + with TemporaryDirectory() as root: + with self.assertRaises(FileNotFoundError): + ClinicalJargonDataset(root=root, cache_dir=self.cache_dir.name) + def test_default_task(self): dataset = self.make_dataset() self.assertIsInstance(dataset.default_task, ClinicalJargonVerification) @@ -74,6 +80,15 @@ def test_task_output_shape(self): self.assertIn("label", sample) self.assertIn("candidate_expansion", sample) + def test_rejects_non_local_casi_cache_names(self): + with patch.object( + ClinicalJargonDataset, + "_download_text", + return_value='[{"name":"../escape.csv","download_url":"https://example.com"}]', + ): + with self.assertRaises(ValueError): + ClinicalJargonDataset._fetch_casi_rows(Path(self.cache_dir.name)) + if __name__ == "__main__": unittest.main() From f9540dfa80b60ef55974ce2efe6a8db161671319 Mon Sep 17 00:00:00 2001 From: John Carson Date: Thu, 9 Apr 2026 19:04:52 -0400 Subject: [PATCH 4/4] Prevent split leakage in clinical jargon example --- ...inical_jargon_verification_transformers.py | 7 ++--- .../clinical_jargon_examples.csv | 1 + tests/core/test_clinical_jargon.py | 26 +++++++++++++++---- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/examples/clinical_jargon_clinical_jargon_verification_transformers.py b/examples/clinical_jargon_clinical_jargon_verification_transformers.py index ea69d08f5..0d985c020 100644 --- a/examples/clinical_jargon_clinical_jargon_verification_transformers.py +++ b/examples/clinical_jargon_clinical_jargon_verification_transformers.py @@ -1,7 +1,8 @@ """Clinical jargon verification example with `TransformersModel`. This example is the course-facing ablation entrypoint for the public clinical -jargon benchmark contribution. It supports three lightweight task ablations: +jargon benchmark contribution. It supports three lightweight task ablations +while keeping each source benchmark item in a single split: - switch between ``medlingo``, ``casi``, and ``all`` benchmark subsets - switch CASI between ``release62`` and ``paper59`` variants @@ -23,7 +24,7 @@ PROJECT_ROOT = Path(__file__).resolve().parents[1] -from pyhealth.datasets import ClinicalJargonDataset, get_dataloader, split_by_sample +from pyhealth.datasets import ClinicalJargonDataset, get_dataloader, split_by_patient from pyhealth.models.transformers_model import TransformersModel from pyhealth.tasks import ClinicalJargonVerification from pyhealth.trainer import Trainer @@ -106,7 +107,7 @@ def main() -> None: "model_name": args.model_name, } ) - train_dataset, val_dataset, test_dataset = split_by_sample( + train_dataset, val_dataset, test_dataset = split_by_patient( samples, [0.6, 0.2, 0.2], seed=42 ) train_loader = get_dataloader(train_dataset, batch_size=args.batch_size, shuffle=True) diff --git a/test-resources/clinical_jargon/clinical_jargon_examples.csv b/test-resources/clinical_jargon/clinical_jargon_examples.csv index d2ea2dd18..2ccf31feb 100644 --- a/test-resources/clinical_jargon/clinical_jargon_examples.csv +++ b/test-resources/clinical_jargon/clinical_jargon_examples.csv @@ -1,6 +1,7 @@ sample_id,benchmark,abbreviation,context,question,question_zero_shot,gold_expansion,gold_aliases_json,surface_form_group,paper59_included,source_file,candidate_expansions_json,candidate_expansions_paper59_json medlingo_001,medlingo,PRN,,"In a clinical note that mentions a high creat, creat stands for creatine. In a clinical note that mentions PRN, PRN stands for","In a clinical note that mentions PRN, PRN stands for",as needed,"[""as needed""]",all_caps,true,medlingo_questions.csv,"[""as needed"", ""vital signs"", ""hypertension"", ""abdomen""]","[""as needed"", ""vital signs"", ""hypertension"", ""abdomen""]" medlingo_002,medlingo,VS,,"In a clinical note that mentions a high creat, creat stands for creatine. In a clinical note that mentions VS every four hours, VS stands for","In a clinical note that mentions VS every four hours, VS stands for",vital signs,"[""vital signs""]",all_caps,true,medlingo_questions.csv,"[""vital signs"", ""as needed"", ""hypertension"", ""abdomen""]","[""vital signs"", ""as needed"", ""hypertension"", ""abdomen""]" +medlingo_003,medlingo,HTN,,"In a clinical note that mentions a high creat, creat stands for creatine. In a clinical note that mentions uncontrolled HTN, HTN stands for","In a clinical note that mentions uncontrolled HTN, HTN stands for",hypertension,"[""hypertension""]",all_caps,true,medlingo_questions.csv,"[""hypertension"", ""as needed"", ""vital signs"", ""abdomen""]","[""hypertension"", ""as needed"", ""vital signs"", ""abdomen""]" casi_0001,casi,AB,"Elective AB in 1989.","Elective AB in 1989. In this sentence, AB means:","Elective AB in 1989. In this sentence, AB means:",abortion,"[""abortion""]",all_caps,true,cleaned_dataset_AB.csv,"[""abortion"", ""blood group in ABO system""]","[""abortion""]" casi_0002,casi,AB,"Type and screen: AB positive.","Type and screen: AB positive. In this sentence, AB means:","Type and screen: AB positive. In this sentence, AB means:","blood group in ABO system","[""blood group in ABO system""]",all_caps,false,cleaned_dataset_AB.csv,"[""abortion"", ""blood group in ABO system""]","[""abortion""]" casi_0003,casi,US,"US abdomen ordered.","US abdomen ordered. In this sentence, US means:","US abdomen ordered. In this sentence, US means:",ultrasound,"[""ultrasound""]",all_caps,true,cleaned_dataset_US.csv,"[""United States"", ""ultrasound""]","[""ultrasound""]" diff --git a/tests/core/test_clinical_jargon.py b/tests/core/test_clinical_jargon.py index 6adc55e9a..4c97cd567 100644 --- a/tests/core/test_clinical_jargon.py +++ b/tests/core/test_clinical_jargon.py @@ -3,7 +3,7 @@ from tempfile import TemporaryDirectory from unittest.mock import patch -from pyhealth.datasets import ClinicalJargonDataset +from pyhealth.datasets import ClinicalJargonDataset, split_by_patient from pyhealth.tasks import ClinicalJargonVerification @@ -42,7 +42,7 @@ def test_default_task(self): def test_num_patients(self): dataset = self.make_dataset() - self.assertEqual(len(dataset.unique_patient_ids), 8) + self.assertEqual(len(dataset.unique_patient_ids), 9) def test_get_patient(self): dataset = self.make_dataset() @@ -54,13 +54,13 @@ def test_release62_task_samples(self): dataset = self.make_dataset() task = ClinicalJargonVerification(benchmark="all", casi_variant="release62") samples = dataset.set_task(task) - self.assertEqual(len(samples), 20) + self.assertEqual(len(samples), 24) def test_paper59_task_samples(self): dataset = self.make_dataset() task = ClinicalJargonVerification(benchmark="all", casi_variant="paper59") samples = dataset.set_task(task) - self.assertEqual(len(samples), 14) + self.assertEqual(len(samples), 18) def test_medlingo_distractor_control(self): dataset = self.make_dataset() @@ -69,7 +69,23 @@ def test_medlingo_distractor_control(self): medlingo_distractors=1, ) samples = dataset.set_task(task) - self.assertEqual(len(samples), 4) + self.assertEqual(len(samples), 6) + + def test_split_by_patient_keeps_candidate_rows_together(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification( + benchmark="medlingo", + medlingo_distractors=1, + ) + samples = dataset.set_task(task) + splits = split_by_patient(samples, [0.6, 0.2, 0.2], seed=42) + split_patient_ids = [ + {split_dataset[index]["patient_id"] for index in range(len(split_dataset))} + for split_dataset in splits + ] + for left_index, left_ids in enumerate(split_patient_ids): + for right_ids in split_patient_ids[left_index + 1 :]: + self.assertTrue(left_ids.isdisjoint(right_ids)) def test_task_output_shape(self): dataset = self.make_dataset()