diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..ec0166994 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -240,6 +240,7 @@ Available Datasets datasets/pyhealth.datasets.ChestXray14Dataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset + datasets/pyhealth.datasets.ClinicalJargonDataset datasets/pyhealth.datasets.ClinVarDataset datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset diff --git a/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst b/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst new file mode 100644 index 000000000..2c98a0637 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst @@ -0,0 +1,21 @@ +pyhealth.datasets.ClinicalJargonDataset +======================================= + +Public clinical jargon benchmark dataset backed by the released MedLingo and +CASI assets from Jia et al. (CHIL 2025). + +Example +------- + +.. code-block:: python + + from pyhealth.datasets import ClinicalJargonDataset + + dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon", download=True) + task = dataset.default_task + samples = dataset.set_task(task) + +.. autoclass:: pyhealth.datasets.ClinicalJargonDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index d85d04bc3..9cf0ac729 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -227,6 +227,7 @@ Available Tasks Benchmark EHRShot ChestX-ray14 Binary Classification ChestX-ray14 Multilabel Classification + Clinical Jargon Verification Variant Classification (ClinVar) Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) diff --git a/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst b/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst new file mode 100644 index 000000000..93d9c8a95 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst @@ -0,0 +1,25 @@ +pyhealth.tasks.ClinicalJargonVerification +========================================= + +Binary candidate-verification task for the public MedLingo and CASI clinical +jargon benchmarks. + +Example +------- + +.. code-block:: python + + from pyhealth.datasets import ClinicalJargonDataset + from pyhealth.tasks import ClinicalJargonVerification + + dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon", download=True) + task = ClinicalJargonVerification( + benchmark="medlingo", + medlingo_distractors=1, + ) + samples = dataset.set_task(task) + +.. autoclass:: pyhealth.tasks.ClinicalJargonVerification + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/clinical_jargon_clinical_jargon_verification_transformers.py b/examples/clinical_jargon_clinical_jargon_verification_transformers.py new file mode 100644 index 000000000..0d985c020 --- /dev/null +++ b/examples/clinical_jargon_clinical_jargon_verification_transformers.py @@ -0,0 +1,130 @@ +"""Clinical jargon verification example with `TransformersModel`. + +This example is the course-facing ablation entrypoint for the public clinical +jargon benchmark contribution. It supports three lightweight task ablations +while keeping each source benchmark item in a single split: + +- switch between ``medlingo``, ``casi``, and ``all`` benchmark subsets +- switch CASI between ``release62`` and ``paper59`` variants +- vary MedLingo distractor count through ``--medlingo-distractors`` + +Run this example from an environment where PyHealth is installed, such as +``pip install -e .`` from the repository root. + +Example commands: + python3 examples/clinical_jargon_clinical_jargon_verification_transformers.py \ + --benchmark medlingo --medlingo-distractors 1 --epochs 1 + + python3 examples/clinical_jargon_clinical_jargon_verification_transformers.py \ + --benchmark casi --casi-variant paper59 --epochs 1 +""" + +import argparse +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parents[1] + +from pyhealth.datasets import ClinicalJargonDataset, get_dataloader, split_by_patient +from pyhealth.models.transformers_model import TransformersModel +from pyhealth.tasks import ClinicalJargonVerification +from pyhealth.trainer import Trainer + + +def parse_args() -> argparse.Namespace: + """Parse CLI arguments for the clinical jargon example. + + Returns: + Parsed command-line arguments for dataset root, model choice, task + configuration, and training hyperparameters. + """ + parser = argparse.ArgumentParser( + description="Clinical jargon verification example with TransformersModel.", + epilog=( + "Ablation knobs: --benchmark changes the benchmark subset, " + "--casi-variant changes the CASI filtering mode, and " + "--medlingo-distractors changes the number of negative MedLingo " + "candidates." + ), + ) + parser.add_argument( + "--root", + type=str, + default=str( + PROJECT_ROOT / "test-resources" / "clinical_jargon" + ), + help="Dataset root containing clinical_jargon_examples.csv.", + ) + parser.add_argument( + "--download", + action="store_true", + help="Fetch the public source assets when the normalized CSV is missing.", + ) + parser.add_argument( + "--model-name", + type=str, + default="emilyalsentzer/Bio_ClinicalBERT", + help="Hugging Face model name used by TransformersModel.", + ) + parser.add_argument( + "--benchmark", + choices=["all", "medlingo", "casi"], + default="medlingo", + help="Benchmark subset used for the run.", + ) + parser.add_argument( + "--casi-variant", + choices=["release62", "paper59"], + default="release62", + help="CASI filtering mode when --benchmark includes CASI.", + ) + parser.add_argument( + "--medlingo-distractors", + type=int, + default=3, + help="Number of negative MedLingo candidates retained per sample.", + ) + parser.add_argument("--epochs", type=int, default=1, help="Training epochs.") + parser.add_argument("--batch-size", type=int, default=4, help="Batch size.") + return parser.parse_args() + + +def main() -> None: + """Run one clinical jargon verification configuration.""" + args = parse_args() + dataset = ClinicalJargonDataset(root=args.root, download=args.download) + task = ClinicalJargonVerification( + benchmark=args.benchmark, + casi_variant=args.casi_variant, + medlingo_distractors=args.medlingo_distractors, + ) + samples = dataset.set_task(task) + print( + { + "benchmark": args.benchmark, + "casi_variant": args.casi_variant, + "medlingo_distractors": args.medlingo_distractors, + "num_samples": len(samples), + "model_name": args.model_name, + } + ) + train_dataset, val_dataset, test_dataset = split_by_patient( + samples, [0.6, 0.2, 0.2], seed=42 + ) + train_loader = get_dataloader(train_dataset, batch_size=args.batch_size, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=args.batch_size, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=args.batch_size, shuffle=False) + + model = TransformersModel(dataset=samples, model_name=args.model_name) + trainer = Trainer(model=model, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + test_dataloader=test_loader, + epochs=args.epochs, + ) + scores = trainer.evaluate(test_loader) + print(scores) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 3edbe06f7..a679a2317 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -65,6 +65,7 @@ def __init__(self, *args, **kwargs): from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset +from .clinical_jargon import ClinicalJargonDataset from .support2 import Support2Dataset from .tcga_prad import TCGAPRADDataset from .splitter import ( diff --git a/pyhealth/datasets/clinical_jargon.py b/pyhealth/datasets/clinical_jargon.py new file mode 100644 index 000000000..e71c59568 --- /dev/null +++ b/pyhealth/datasets/clinical_jargon.py @@ -0,0 +1,421 @@ +"""Clinical jargon benchmark dataset for PyHealth. + +This module exposes a public clinical jargon benchmark derived from the +released MedLingo and CASI assets from Jia et al. (CHIL 2025). The dataset is +normalized into a single CSV file that PyHealth can load as an `examples` +table. +""" + +import csv +import json +import re +import urllib.request +from collections import defaultdict +from pathlib import Path +from typing import Optional + +from .base_dataset import BaseDataset + + +MEDLINGO_URL = ( + "https://raw.githubusercontent.com/Flora-jia-jfr/diagnosing_our_datasets/" + "main/datasets/MedLingo/questions.csv" +) +CASI_RELEASE_INDEX_URL = ( + "https://api.github.com/repos/Flora-jia-jfr/diagnosing_our_datasets/contents/" + "datasets/casi/cleaned_dataset_subset?ref=main" +) +MEDLINGO_ONESHOT_PREFIX = ( + "In a clinical note that mentions a high creat, creat stands for creatine. " +) +DOWNLOAD_TIMEOUT_SECONDS = 10 +PAPER59_EXCLUSIONS = frozenset( + { + ("AB", "blood group in ABO system"), + ("US", "United States"), + ("IB", "international baccalaureate"), + ("MS", "master of science"), + ("MP", "military police"), + ("PD", "police department"), + ("MP", "metatarsophalangeal/metacarpophalangeal"), + ("OP", "oblique presentation/occiput posterior"), + ("SA", "slow acting/sustained action"), + ("C&S", "conjunctivae and sclerae"), + ("C&S", "culture and sensitivity"), + ("C&S", "protein C and protein S"), + } +) + + +def split_aliases(answer: str) -> list[str]: + """Split released answer strings into canonical aliases. + + Args: + answer: The released answer string. Some MedLingo answers contain + multiple acceptable expansions joined by ``or``. + + Returns: + A non-empty list of acceptable expansion strings. + """ + pieces = re.split(r"\s+or\s+", answer.strip()) + aliases = [piece.strip() for piece in pieces if piece.strip()] + return aliases or [answer.strip()] + + +def surface_form_group(abbreviation: str) -> str: + """Assign a surface-form bucket to a jargon token. + + Args: + abbreviation: The shorthand token being evaluated. + + Returns: + One of ``all_caps``, ``lowercase``, ``mixed_case``, or + ``digit_or_symbol``. + """ + if any(character.isdigit() or not character.isalpha() for character in abbreviation): + return "digit_or_symbol" + if abbreviation.isupper(): + return "all_caps" + if abbreviation.islower(): + return "lowercase" + return "mixed_case" + + +def strip_medlingo_oneshot(question: str) -> str: + """Remove the released MedLingo one-shot demonstration when present. + + Args: + question: A released MedLingo question string. + + Returns: + The same question without the built-in one-shot example prefix. + """ + if question.startswith(MEDLINGO_ONESHOT_PREFIX): + return question[len(MEDLINGO_ONESHOT_PREFIX) :].strip() + return question.strip() + + +def dedupe(values: list[str]) -> list[str]: + """Preserve order while removing duplicate strings. + + Args: + values: Ordered candidate strings. + + Returns: + The input values with duplicates removed in first-seen order. + """ + seen: set[str] = set() + ordered: list[str] = [] + for value in values: + if value not in seen: + seen.add(value) + ordered.append(value) + return ordered + + +def token_length(text: str) -> int: + """Count alphanumeric tokens in a string. + + Args: + text: The input text. + + Returns: + The number of regex word tokens in ``text``. + """ + return len(re.findall(r"\w+", text)) + + +def choose_medlingo_distractors( + records: list[dict], + current_record: dict, + distractor_count: int = 3, +) -> list[str]: + """Select distractor expansions for a MedLingo item. + + The ranking favors candidate expansions with similar token length, first + within the same surface-form group and then globally if more negatives are + needed. + + Args: + records: All normalized MedLingo records. + current_record: The record whose distractors are being chosen. + distractor_count: Number of negative candidates to return. + + Returns: + A list of distractor expansions ordered from closest to farthest match. + """ + gold = current_record["gold_expansion"] + goal_length = token_length(gold) + + def rank(pool: list[str]) -> list[str]: + return sorted( + dedupe([value for value in pool if value != gold]), + key=lambda value: (abs(token_length(value) - goal_length), value.lower()), + ) + + same_group = [ + record["gold_expansion"] + for record in records + if record["sample_id"] != current_record["sample_id"] + and record["surface_form_group"] == current_record["surface_form_group"] + ] + global_pool = [ + record["gold_expansion"] + for record in records + if record["sample_id"] != current_record["sample_id"] + ] + negatives = rank(same_group) + if len(negatives) < distractor_count: + for candidate in rank(global_pool): + if candidate not in negatives: + negatives.append(candidate) + if len(negatives) == distractor_count: + break + return negatives[:distractor_count] + + +class ClinicalJargonDataset(BaseDataset): + """Public clinical jargon benchmark dataset for PyHealth. + + The dataset downloads the public MedLingo and CASI benchmark assets, + normalizes them into a single ``clinical_jargon_examples.csv`` file, and + exposes the result through the PyHealth dataset API. + + The default task is :class:`pyhealth.tasks.ClinicalJargonVerification`, + which converts each benchmark item into paired-text binary verification + samples over candidate expansions. + + Args: + root: Root directory used to store normalized benchmark files. + dataset_name: Optional dataset name. Defaults to ``clinical_jargon``. + config_path: Optional path to the dataset config file. + download: Whether to download and normalize the public source assets + when ``clinical_jargon_examples.csv`` is missing. Defaults to + ``False``. + **kwargs: Additional keyword arguments forwarded to + :class:`pyhealth.datasets.BaseDataset`. + + Examples: + >>> from pyhealth.datasets import ClinicalJargonDataset + >>> dataset = ClinicalJargonDataset( + ... root="/tmp/clinical_jargon", + ... download=True, + ... ) + >>> task = dataset.default_task + >>> samples = dataset.set_task(task) + >>> print(samples[0]["paired_text"]) + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + download: bool = False, + **kwargs, + ) -> None: + """Initialize the public clinical jargon dataset. + + Args: + root: Root directory used to cache normalized files. + dataset_name: Optional dataset name override. + config_path: Optional dataset config path override. + download: Whether to fetch and normalize the released benchmark + assets when the normalized CSV is missing. + **kwargs: Additional keyword arguments passed to ``BaseDataset``. + """ + root_path = Path(root) + root_path.mkdir(parents=True, exist_ok=True) + if config_path is None: + config_path = Path(__file__).parent / "configs" / "clinical_jargon.yaml" + normalized_csv = root_path / "clinical_jargon_examples.csv" + if not normalized_csv.exists(): + if not download: + raise FileNotFoundError( + f"Missing normalized metadata at {normalized_csv}. " + "Pass download=True to fetch the public MedLingo and CASI " + "assets and generate this CSV." + ) + self.prepare_metadata(root_path) + super().__init__( + root=str(root_path), + tables=["examples"], + dataset_name=dataset_name or "clinical_jargon", + config_path=str(config_path), + **kwargs, + ) + + @staticmethod + def _download_text(url: str, destination: Path) -> str: + """Download text content unless it is already cached locally. + + Args: + url: Source URL. + destination: Cache path for the downloaded content. + + Returns: + The downloaded or cached text payload. + """ + if destination.exists(): + return destination.read_text(encoding="utf-8", errors="replace") + request = urllib.request.Request(url) + with urllib.request.urlopen( + request, + timeout=DOWNLOAD_TIMEOUT_SECONDS, + ) as response: + payload = response.read().decode("utf-8", errors="replace") + destination.write_text(payload, encoding="utf-8", errors="replace") + return payload + + @staticmethod + def _validated_file_name(file_name: str) -> str: + """Validate a remotely provided cache filename.""" + candidate = Path(file_name) + if ( + not file_name + or candidate.is_absolute() + or candidate.name != file_name + or candidate.parent != Path(".") + ): + raise ValueError(f"Invalid cache file name: {file_name}") + return file_name + + @classmethod + def _fetch_medlingo_rows(cls, cache_dir: Path) -> list[dict]: + """Load raw MedLingo rows from the released public CSV. + + Args: + cache_dir: Cache directory for downloaded assets. + + Returns: + Raw MedLingo rows as dictionaries. + """ + csv_text = cls._download_text(MEDLINGO_URL, cache_dir / "medlingo_questions.csv") + return list(csv.DictReader(csv_text.splitlines())) + + @classmethod + def _fetch_casi_rows(cls, cache_dir: Path) -> list[dict]: + """Load raw CASI rows from the released public subset. + + Args: + cache_dir: Cache directory for downloaded assets. + + Returns: + Raw CASI rows as dictionaries with source-file metadata. + """ + index_path = cache_dir / "casi_release_index.json" + entries = json.loads(cls._download_text(CASI_RELEASE_INDEX_URL, index_path)) + rows: list[dict] = [] + for entry in entries: + file_name = cls._validated_file_name(entry["name"]) + file_text = cls._download_text(entry["download_url"], cache_dir / file_name) + for row in csv.DictReader(file_text.splitlines()): + row["source_file"] = file_name + rows.append(row) + return rows + + @classmethod + def prepare_metadata(cls, root: Path) -> None: + """Normalize public MedLingo and CASI assets into one CSV file. + + Args: + root: Root directory where the normalized file and cache should be + written. + """ + cache_dir = root / "cache" + cache_dir.mkdir(parents=True, exist_ok=True) + + medlingo_rows: list[dict] = [] + for index, row in enumerate(cls._fetch_medlingo_rows(cache_dir), start=1): + aliases = split_aliases(row["answer"]) + medlingo_rows.append( + { + "sample_id": f"medlingo_{index:03d}", + "benchmark": "medlingo", + "abbreviation": row["word1"], + "context": "", + "question": row["question"].strip(), + "question_zero_shot": strip_medlingo_oneshot(row["question"]), + "gold_expansion": aliases[0], + "gold_aliases_json": json.dumps(aliases), + "surface_form_group": surface_form_group(row["word1"]), + "paper59_included": "true", + "source_file": "medlingo_questions.csv", + } + ) + + medlingo_candidate_map = {} + for row in medlingo_rows: + negatives = choose_medlingo_distractors(medlingo_rows, row, distractor_count=3) + medlingo_candidate_map[row["sample_id"]] = [row["gold_expansion"], *negatives] + + for row in medlingo_rows: + row["candidate_expansions_json"] = json.dumps( + medlingo_candidate_map[row["sample_id"]] + ) + row["candidate_expansions_paper59_json"] = row["candidate_expansions_json"] + + casi_rows = cls._fetch_casi_rows(cache_dir) + release_expansions: dict[str, list[str]] = defaultdict(list) + paper59_expansions: dict[str, list[str]] = defaultdict(list) + for row in casi_rows: + abbreviation = row["sf"] + expansion = row["target_lf"] + release_expansions[abbreviation].append(expansion) + if (abbreviation, expansion) not in PAPER59_EXCLUSIONS: + paper59_expansions[abbreviation].append(expansion) + + release_expansions = { + key: dedupe(sorted(values, key=str.lower)) + for key, values in release_expansions.items() + } + paper59_expansions = { + key: dedupe(sorted(values, key=str.lower)) + for key, values in paper59_expansions.items() + } + + normalized_rows = list(medlingo_rows) + for index, row in enumerate(casi_rows, start=1): + abbreviation = row["sf"] + expansion = row["target_lf"] + question = f"{row['context'].strip()} In this sentence, {abbreviation} means:" + normalized_rows.append( + { + "sample_id": f"casi_{index:04d}", + "benchmark": "casi", + "abbreviation": abbreviation, + "context": row["context"].strip(), + "question": question, + "question_zero_shot": question, + "gold_expansion": expansion, + "gold_aliases_json": json.dumps([expansion]), + "surface_form_group": surface_form_group(abbreviation), + "paper59_included": str( + (abbreviation, expansion) not in PAPER59_EXCLUSIONS + ).lower(), + "source_file": row["source_file"], + "candidate_expansions_json": json.dumps( + release_expansions[abbreviation] + ), + "candidate_expansions_paper59_json": json.dumps( + paper59_expansions.get(abbreviation, []) + ), + } + ) + + output_path = root / "clinical_jargon_examples.csv" + with output_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=list(normalized_rows[0].keys())) + writer.writeheader() + writer.writerows(normalized_rows) + + @property + def default_task(self): + """Return the default task for the dataset. + + Returns: + ClinicalJargonVerification: The default binary verification task. + """ + from ..tasks import ClinicalJargonVerification + + return ClinicalJargonVerification() diff --git a/pyhealth/datasets/configs/clinical_jargon.yaml b/pyhealth/datasets/configs/clinical_jargon.yaml new file mode 100644 index 000000000..9592c5114 --- /dev/null +++ b/pyhealth/datasets/configs/clinical_jargon.yaml @@ -0,0 +1,20 @@ +version: "1.0" +tables: + examples: + file_path: "clinical_jargon_examples.csv" + patient_id: null + timestamp: null + attributes: + - "sample_id" + - "benchmark" + - "abbreviation" + - "context" + - "question" + - "question_zero_shot" + - "gold_expansion" + - "gold_aliases_json" + - "surface_form_group" + - "paper59_included" + - "source_file" + - "candidate_expansions_json" + - "candidate_expansions_paper59_json" diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..e90507df7 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -11,6 +11,7 @@ ) from .chestxray14_binary_classification import ChestXray14BinaryClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification +from .clinical_jargon_verification import ClinicalJargonVerification from .covid19_cxr_classification import COVID19CXRClassification from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 from .drug_recommendation import ( diff --git a/pyhealth/tasks/clinical_jargon_verification.py b/pyhealth/tasks/clinical_jargon_verification.py new file mode 100644 index 000000000..76ecc8657 --- /dev/null +++ b/pyhealth/tasks/clinical_jargon_verification.py @@ -0,0 +1,132 @@ +import json +from typing import Any, Dict, List + +from ..data import Patient +from .base_task import BaseTask + + +def parse_bool(value: Any) -> bool: + """Normalize released boolean-like values. + + Args: + value: A boolean or string-like value from the normalized benchmark. + + Returns: + ``True`` when the input represents a truthy flag, otherwise ``False``. + """ + if isinstance(value, bool): + return value + return str(value).strip().lower() == "true" + + +class ClinicalJargonVerification(BaseTask): + """Binary candidate-verification task for public clinical jargon benchmarks. + + Each sample pairs a benchmark question with one candidate expansion and + asks whether that candidate is correct. This reframes the released jargon + benchmark into a standard PyHealth binary classification task. + + Args: + benchmark: Benchmark subset to include. One of ``all``, ``medlingo``, + or ``casi``. + casi_variant: Which CASI view to expose. ``release62`` uses the public + release as-is, while ``paper59`` applies the local exclusion set + used to approximate the paper's filtered benchmark. + medlingo_distractors: Number of negative MedLingo candidates to retain + per example. Valid values are ``1``, ``2``, and ``3``. + + Examples: + >>> from pyhealth.datasets import ClinicalJargonDataset + >>> from pyhealth.tasks import ClinicalJargonVerification + >>> dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon") + >>> task = ClinicalJargonVerification( + ... benchmark="medlingo", + ... medlingo_distractors=1, + ... ) + >>> samples = dataset.set_task(task) + >>> print(samples[0]["label"]) + """ + + task_name: str = "ClinicalJargonVerification" + input_schema: Dict[str, str] = {"paired_text": "text"} + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__( + self, + benchmark: str = "all", + casi_variant: str = "release62", + medlingo_distractors: int = 3, + ) -> None: + """Initialize the clinical jargon verification task. + + Args: + benchmark: Benchmark subset to include. + casi_variant: CASI variant to expose. + medlingo_distractors: Number of negative MedLingo candidates. + + Raises: + ValueError: If any task configuration argument is unsupported. + """ + if benchmark not in {"all", "medlingo", "casi"}: + raise ValueError(f"Unsupported benchmark: {benchmark}") + if casi_variant not in {"release62", "paper59"}: + raise ValueError(f"Unsupported CASI variant: {casi_variant}") + if medlingo_distractors not in {1, 2, 3}: + raise ValueError("medlingo_distractors must be 1, 2, or 3") + self.benchmark = benchmark + self.casi_variant = casi_variant + self.medlingo_distractors = medlingo_distractors + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Generate verification samples for one patient record. + + Args: + patient: A PyHealth patient object containing one ``examples`` + event from the normalized clinical jargon dataset. + + Returns: + A list of binary verification samples. MedLingo records emit one + positive and ``medlingo_distractors`` negative candidates. CASI + emits all released candidates for the selected variant. + """ + events = patient.get_events(event_type="examples") + if len(events) != 1: + return [] + event = events[0] + + if self.benchmark != "all" and event.benchmark != self.benchmark: + return [] + if event.benchmark == "casi" and self.casi_variant == "paper59": + if not parse_bool(event.paper59_included): + return [] + + if event.benchmark == "casi" and self.casi_variant == "paper59": + candidates = json.loads(event.candidate_expansions_paper59_json) + else: + candidates = json.loads(event.candidate_expansions_json) + + if event.benchmark == "medlingo": + gold = event.gold_expansion + negatives = [candidate for candidate in candidates if candidate != gold] + candidates = [gold, *negatives[: self.medlingo_distractors]] + + samples: List[Dict[str, Any]] = [] + for candidate in candidates: + samples.append( + { + "patient_id": patient.patient_id, + "record_id": event.sample_id, + "sample_id": event.sample_id, + "benchmark": event.benchmark, + "abbreviation": event.abbreviation, + "candidate_expansion": candidate, + "surface_form_group": event.surface_form_group, + "paired_text": ( + f"Question: {event.question}\n" + f"Candidate expansion: {candidate}\n" + "Is this the correct expansion?" + ), + "label": int(candidate == event.gold_expansion), + } + ) + return samples diff --git a/test-resources/clinical_jargon/clinical_jargon_examples.csv b/test-resources/clinical_jargon/clinical_jargon_examples.csv new file mode 100644 index 000000000..2ccf31feb --- /dev/null +++ b/test-resources/clinical_jargon/clinical_jargon_examples.csv @@ -0,0 +1,10 @@ +sample_id,benchmark,abbreviation,context,question,question_zero_shot,gold_expansion,gold_aliases_json,surface_form_group,paper59_included,source_file,candidate_expansions_json,candidate_expansions_paper59_json +medlingo_001,medlingo,PRN,,"In a clinical note that mentions a high creat, creat stands for creatine. In a clinical note that mentions PRN, PRN stands for","In a clinical note that mentions PRN, PRN stands for",as needed,"[""as needed""]",all_caps,true,medlingo_questions.csv,"[""as needed"", ""vital signs"", ""hypertension"", ""abdomen""]","[""as needed"", ""vital signs"", ""hypertension"", ""abdomen""]" +medlingo_002,medlingo,VS,,"In a clinical note that mentions a high creat, creat stands for creatine. In a clinical note that mentions VS every four hours, VS stands for","In a clinical note that mentions VS every four hours, VS stands for",vital signs,"[""vital signs""]",all_caps,true,medlingo_questions.csv,"[""vital signs"", ""as needed"", ""hypertension"", ""abdomen""]","[""vital signs"", ""as needed"", ""hypertension"", ""abdomen""]" +medlingo_003,medlingo,HTN,,"In a clinical note that mentions a high creat, creat stands for creatine. In a clinical note that mentions uncontrolled HTN, HTN stands for","In a clinical note that mentions uncontrolled HTN, HTN stands for",hypertension,"[""hypertension""]",all_caps,true,medlingo_questions.csv,"[""hypertension"", ""as needed"", ""vital signs"", ""abdomen""]","[""hypertension"", ""as needed"", ""vital signs"", ""abdomen""]" +casi_0001,casi,AB,"Elective AB in 1989.","Elective AB in 1989. In this sentence, AB means:","Elective AB in 1989. In this sentence, AB means:",abortion,"[""abortion""]",all_caps,true,cleaned_dataset_AB.csv,"[""abortion"", ""blood group in ABO system""]","[""abortion""]" +casi_0002,casi,AB,"Type and screen: AB positive.","Type and screen: AB positive. In this sentence, AB means:","Type and screen: AB positive. In this sentence, AB means:","blood group in ABO system","[""blood group in ABO system""]",all_caps,false,cleaned_dataset_AB.csv,"[""abortion"", ""blood group in ABO system""]","[""abortion""]" +casi_0003,casi,US,"US abdomen ordered.","US abdomen ordered. In this sentence, US means:","US abdomen ordered. In this sentence, US means:",ultrasound,"[""ultrasound""]",all_caps,true,cleaned_dataset_US.csv,"[""United States"", ""ultrasound""]","[""ultrasound""]" +casi_0004,casi,US,"Travel history from the US noted.","Travel history from the US noted. In this sentence, US means:","Travel history from the US noted. In this sentence, US means:","United States","[""United States""]",all_caps,false,cleaned_dataset_US.csv,"[""United States"", ""ultrasound""]","[""ultrasound""]" +casi_0005,casi,CA,"History notable for CA of the colon.","History notable for CA of the colon. In this sentence, CA means:","History notable for CA of the colon. In this sentence, CA means:",cancer,"[""cancer""]",all_caps,true,cleaned_dataset_CA.csv,"[""cancer"", ""carbohydrate antigen""]","[""cancer"", ""carbohydrate antigen""]" +casi_0006,casi,CA,"Lab review included an elevated CA 19-9.","Lab review included an elevated CA 19-9. In this sentence, CA means:","Lab review included an elevated CA 19-9. In this sentence, CA means:","carbohydrate antigen","[""carbohydrate antigen""]",all_caps,true,cleaned_dataset_CA.csv,"[""cancer"", ""carbohydrate antigen""]","[""cancer"", ""carbohydrate antigen""]" diff --git a/tests/core/test_clinical_jargon.py b/tests/core/test_clinical_jargon.py new file mode 100644 index 000000000..4c97cd567 --- /dev/null +++ b/tests/core/test_clinical_jargon.py @@ -0,0 +1,110 @@ +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import patch + +from pyhealth.datasets import ClinicalJargonDataset, split_by_patient +from pyhealth.tasks import ClinicalJargonVerification + + +class TestClinicalJargonDataset(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.root = ( + Path(__file__).parent.parent.parent + / "test-resources" + / "clinical_jargon" + ) + cls.cache_dir = TemporaryDirectory() + + @classmethod + def tearDownClass(cls): + cls.cache_dir.cleanup() + + def make_dataset(self): + return ClinicalJargonDataset( + root=str(self.root), + cache_dir=self.cache_dir.name, + ) + + def test_dataset_initialization(self): + dataset = self.make_dataset() + self.assertEqual(dataset.dataset_name, "clinical_jargon") + + def test_missing_csv_requires_explicit_download(self): + with TemporaryDirectory() as root: + with self.assertRaises(FileNotFoundError): + ClinicalJargonDataset(root=root, cache_dir=self.cache_dir.name) + + def test_default_task(self): + dataset = self.make_dataset() + self.assertIsInstance(dataset.default_task, ClinicalJargonVerification) + + def test_num_patients(self): + dataset = self.make_dataset() + self.assertEqual(len(dataset.unique_patient_ids), 9) + + def test_get_patient(self): + dataset = self.make_dataset() + patient = dataset.get_patient(dataset.unique_patient_ids[0]) + self.assertIsNotNone(patient) + self.assertEqual(len(patient.get_events("examples")), 1) + + def test_release62_task_samples(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification(benchmark="all", casi_variant="release62") + samples = dataset.set_task(task) + self.assertEqual(len(samples), 24) + + def test_paper59_task_samples(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification(benchmark="all", casi_variant="paper59") + samples = dataset.set_task(task) + self.assertEqual(len(samples), 18) + + def test_medlingo_distractor_control(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification( + benchmark="medlingo", + medlingo_distractors=1, + ) + samples = dataset.set_task(task) + self.assertEqual(len(samples), 6) + + def test_split_by_patient_keeps_candidate_rows_together(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification( + benchmark="medlingo", + medlingo_distractors=1, + ) + samples = dataset.set_task(task) + splits = split_by_patient(samples, [0.6, 0.2, 0.2], seed=42) + split_patient_ids = [ + {split_dataset[index]["patient_id"] for index in range(len(split_dataset))} + for split_dataset in splits + ] + for left_index, left_ids in enumerate(split_patient_ids): + for right_ids in split_patient_ids[left_index + 1 :]: + self.assertTrue(left_ids.isdisjoint(right_ids)) + + def test_task_output_shape(self): + dataset = self.make_dataset() + task = ClinicalJargonVerification(benchmark="casi", casi_variant="paper59") + samples = dataset.set_task(task) + sample = samples[0] + self.assertIn("paired_text", sample) + self.assertIn("label", sample) + self.assertIn("candidate_expansion", sample) + + def test_rejects_non_local_casi_cache_names(self): + with patch.object( + ClinicalJargonDataset, + "_download_text", + return_value='[{"name":"../escape.csv","download_url":"https://example.com"}]', + ): + with self.assertRaises(ValueError): + ClinicalJargonDataset._fetch_casi_rows(Path(self.cache_dir.name)) + + +if __name__ == "__main__": + unittest.main()