Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ Available Datasets
datasets/pyhealth.datasets.ChestXray14Dataset
datasets/pyhealth.datasets.TUABDataset
datasets/pyhealth.datasets.TUEVDataset
datasets/pyhealth.datasets.ClinicalJargonDataset
datasets/pyhealth.datasets.ClinVarDataset
datasets/pyhealth.datasets.COSMICDataset
datasets/pyhealth.datasets.TCGAPRADDataset
Expand Down
21 changes: 21 additions & 0 deletions docs/api/datasets/pyhealth.datasets.ClinicalJargonDataset.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
pyhealth.datasets.ClinicalJargonDataset
=======================================

Public clinical jargon benchmark dataset backed by the released MedLingo and
CASI assets from Jia et al. (CHIL 2025).

Example
-------

.. code-block:: python

    from pyhealth.datasets import ClinicalJargonDataset

    dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon", download=True)
    task = dataset.default_task
    samples = dataset.set_task(task)

.. autoclass:: pyhealth.datasets.ClinicalJargonDataset
    :members:
    :undoc-members:
    :show-inheritance:
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ Available Tasks
Benchmark EHRShot <tasks/pyhealth.tasks.benchmark_ehrshot>
ChestX-ray14 Binary Classification <tasks/pyhealth.tasks.ChestXray14BinaryClassification>
ChestX-ray14 Multilabel Classification <tasks/pyhealth.tasks.ChestXray14MultilabelClassification>
Clinical Jargon Verification <tasks/pyhealth.tasks.ClinicalJargonVerification>
Variant Classification (ClinVar) <tasks/pyhealth.tasks.VariantClassificationClinVar>
Mutation Pathogenicity (COSMIC) <tasks/pyhealth.tasks.MutationPathogenicityPrediction>
Cancer Survival Prediction (TCGA) <tasks/pyhealth.tasks.CancerSurvivalPrediction>
Expand Down
25 changes: 25 additions & 0 deletions docs/api/tasks/pyhealth.tasks.ClinicalJargonVerification.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
pyhealth.tasks.ClinicalJargonVerification
=========================================

Binary candidate-verification task for the public MedLingo and CASI clinical
jargon benchmarks.

Example
-------

.. code-block:: python

    from pyhealth.datasets import ClinicalJargonDataset
    from pyhealth.tasks import ClinicalJargonVerification

    dataset = ClinicalJargonDataset(root="/tmp/clinical_jargon", download=True)
    task = ClinicalJargonVerification(
        benchmark="medlingo",
        medlingo_distractors=1,
    )
    samples = dataset.set_task(task)

.. autoclass:: pyhealth.tasks.ClinicalJargonVerification
    :members:
    :undoc-members:
    :show-inheritance:
130 changes: 130 additions & 0 deletions examples/clinical_jargon_clinical_jargon_verification_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Clinical jargon verification example with `TransformersModel`.

This example is the course-facing ablation entrypoint for the public clinical
jargon benchmark contribution. It supports three lightweight task ablations
while keeping each source benchmark item in a single split:

- switch between ``medlingo``, ``casi``, and ``all`` benchmark subsets
- switch CASI between ``release62`` and ``paper59`` variants
- vary MedLingo distractor count through ``--medlingo-distractors``

Run this example from an environment where PyHealth is installed, such as
``pip install -e .`` from the repository root.

Example commands:
python3 examples/clinical_jargon_clinical_jargon_verification_transformers.py \
--benchmark medlingo --medlingo-distractors 1 --epochs 1

python3 examples/clinical_jargon_clinical_jargon_verification_transformers.py \
--benchmark casi --casi-variant paper59 --epochs 1
"""

import argparse
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parents[1]

from pyhealth.datasets import ClinicalJargonDataset, get_dataloader, split_by_patient
from pyhealth.models.transformers_model import TransformersModel
from pyhealth.tasks import ClinicalJargonVerification
from pyhealth.trainer import Trainer


def parse_args() -> argparse.Namespace:
    """Build the example's CLI parser and parse the process arguments.

    The options are declared as a table of ``(flag, add_argument kwargs)``
    pairs and registered in order, so the help output lists them in the
    same sequence as the table.

    Returns:
        Namespace holding the dataset root, download flag, model name,
        benchmark subset, CASI variant, MedLingo distractor count, epoch
        count, and batch size.
    """
    # The default root is resolved relative to the repository checkout.
    default_root = PROJECT_ROOT.joinpath("test-resources", "clinical_jargon")
    epilog_text = (
        "Ablation knobs: --benchmark changes the benchmark subset, "
        "--casi-variant changes the CASI filtering mode, and "
        "--medlingo-distractors changes the number of negative MedLingo "
        "candidates."
    )
    option_specs = (
        (
            "--root",
            {
                "type": str,
                "default": str(default_root),
                "help": "Dataset root containing clinical_jargon_examples.csv.",
            },
        ),
        (
            "--download",
            {
                "action": "store_true",
                "help": (
                    "Fetch the public source assets when the normalized CSV "
                    "is missing."
                ),
            },
        ),
        (
            "--model-name",
            {
                "type": str,
                "default": "emilyalsentzer/Bio_ClinicalBERT",
                "help": "Hugging Face model name used by TransformersModel.",
            },
        ),
        (
            "--benchmark",
            {
                "choices": ["all", "medlingo", "casi"],
                "default": "medlingo",
                "help": "Benchmark subset used for the run.",
            },
        ),
        (
            "--casi-variant",
            {
                "choices": ["release62", "paper59"],
                "default": "release62",
                "help": "CASI filtering mode when --benchmark includes CASI.",
            },
        ),
        (
            "--medlingo-distractors",
            {
                "type": int,
                "default": 3,
                "help": (
                    "Number of negative MedLingo candidates retained per "
                    "sample."
                ),
            },
        ),
        ("--epochs", {"type": int, "default": 1, "help": "Training epochs."}),
        ("--batch-size", {"type": int, "default": 4, "help": "Batch size."}),
    )

    cli = argparse.ArgumentParser(
        description="Clinical jargon verification example with TransformersModel.",
        epilog=epilog_text,
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def main() -> None:
    """Train and evaluate one clinical jargon verification configuration.

    Parses the CLI options, builds the dataset and the verification task,
    prints a summary of the run, splits the samples with
    ``split_by_patient`` (ratios 0.6/0.2/0.2, seed 42), trains a
    ``TransformersModel`` through ``Trainer``, and prints the scores
    returned by evaluating on the test loader.
    """
    cli = parse_args()

    jargon_dataset = ClinicalJargonDataset(root=cli.root, download=cli.download)
    verification_task = ClinicalJargonVerification(
        benchmark=cli.benchmark,
        casi_variant=cli.casi_variant,
        medlingo_distractors=cli.medlingo_distractors,
    )
    task_samples = jargon_dataset.set_task(verification_task)

    # Echo the effective configuration so ablation runs are easy to tell apart.
    run_summary = {
        "benchmark": cli.benchmark,
        "casi_variant": cli.casi_variant,
        "medlingo_distractors": cli.medlingo_distractors,
        "num_samples": len(task_samples),
        "model_name": cli.model_name,
    }
    print(run_summary)

    split_ratios = [0.6, 0.2, 0.2]
    train_split, val_split, test_split = split_by_patient(
        task_samples, split_ratios, seed=42
    )

    def make_loader(split, shuffle):
        """Wrap ``split`` in a dataloader using the CLI batch size."""
        return get_dataloader(split, batch_size=cli.batch_size, shuffle=shuffle)

    # Only the training loader is shuffled.
    train_loader = make_loader(train_split, True)
    val_loader = make_loader(val_split, False)
    test_loader = make_loader(test_split, False)

    model = TransformersModel(dataset=task_samples, model_name=cli.model_name)
    trainer = Trainer(model=model, enable_logging=False)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        test_dataloader=test_loader,
        epochs=cli.epochs,
    )
    print(trainer.evaluate(test_loader))


# Run the example only when this file is executed as a script, so importing
# the module does not trigger dataset loading or training.
if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(self, *args, **kwargs):
from .shhs import SHHSDataset
from .sleepedf import SleepEDFDataset
from .bmd_hs import BMDHSDataset
from .clinical_jargon import ClinicalJargonDataset
from .support2 import Support2Dataset
from .tcga_prad import TCGAPRADDataset
from .splitter import (
Expand Down
Loading