Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
c29d443
Fixed repo to be able to run TUEV/TUAB + updated example scripts
lehendo Mar 15, 2026
7f1f7c3
Args need to be passed correctly
lehendo Mar 15, 2026
539f91e
Minor fixes and precomputed STFT logic
lehendo Mar 15, 2026
97095de
Fix the test files to reflect codebase changes
lehendo Mar 15, 2026
4ce7e81
Args update
lehendo Mar 15, 2026
33ab326
Merge branch 'sunlabuiuc:master' into arjunfixes1
lehendo Mar 16, 2026
5fd13a3
test script fixes
lehendo Mar 16, 2026
29dd560
Merge branch 'sunlabuiuc:master' into arjunfixes1
lehendo Mar 16, 2026
4a192b2
dataset path update
lehendo Mar 16, 2026
dad5772
fix contrawr - small change
lehendo Mar 16, 2026
6f852e6
divide by 0 error
lehendo Mar 16, 2026
f428a5c
Incorporate tfm logic
lehendo Mar 16, 2026
bf0fac6
Fix label stuff
lehendo Mar 16, 2026
ea3b23b
tuab fixes
lehendo Mar 17, 2026
9a1119b
fix metrics
lehendo Mar 17, 2026
fa94d7a
aggregate alphas
lehendo Mar 17, 2026
ce591d9
Fix splitting and add tfm weights
lehendo Mar 18, 2026
0ea5a04
fix tfm+tuab
lehendo Mar 19, 2026
9d06098
updates scripts and haoyu splitter
lehendo Mar 23, 2026
72bcdaf
fix conflict
lehendo Mar 23, 2026
dae809a
Merge branch 'master' into arjunfixes1
lehendo Mar 23, 2026
1a17491
Remove weightfiles from tracking and add to .gitignore
lehendo Mar 23, 2026
8ca6010
Merge branch 'sunlabuiuc:master' into arjunfixes1
lehendo Mar 25, 2026
e39f01c
Merge branch 'sunlabuiuc:master' into arjunfixes1
lehendo Mar 29, 2026
eb53b17
normalization = 95%
lehendo Mar 29, 2026
2491fcf
temporarily re-add weight files
lehendo Mar 29, 2026
38f7168
16 workers
lehendo Mar 29, 2026
748e984
tuab sanity check
lehendo Mar 29, 2026
0af1e0a
consistent log outputs
lehendo Mar 30, 2026
67fda60
Merge branch 'sunlabuiuc:master' into arjunfixes1
lehendo Mar 30, 2026
1fc1ff3
test tuab
lehendo Mar 30, 2026
588b8b6
Merge branch 'sunlabuiuc:master' into arjunfixes1
lehendo Apr 3, 2026
1687e36
change back to multiclass
lehendo Apr 3, 2026
edddd23
update conformal scripts
lehendo Apr 3, 2026
e4a778b
remove weightfiles
lehendo Apr 4, 2026
01bca8b
oops
lehendo Apr 4, 2026
1a88362
fix tests
lehendo Apr 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 197 additions & 0 deletions examples/conformal_eeg/test_tfm_tuab_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""
Quick inference test: TFMTokenizer on TUAB using local weightfiles/.

Two weight setups (ask your PI which matches their training):

1) Default (matches conformal example scripts):
- tokenizer: weightfiles/tfm_tokenizer_last.pth (multi-dataset tokenizer)
- classifier: weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB/.../best_model.pth

2) PI benchmark TUAB-specific files (place in weightfiles/):
- tokenizer: tfm_tokenizer_tuab.pth
- classifier: tfm_encoder_best_model_tuab.pth
Use: --pi-tuab-weights

Split modes:
- conformal (default): same test set as conformal runs (TUH eval via patient conformal split).
- pi_benchmark: train/val ratio [0.875, 0.125] on train partition; test = TUH eval (same patients as official eval).

Usage:
python examples/conformal_eeg/test_tfm_tuab_inference.py
python examples/conformal_eeg/test_tfm_tuab_inference.py --pi-tuab-weights
python examples/conformal_eeg/test_tfm_tuab_inference.py \\
--tuab-pi-weights-dir /shared/eng/conformal_eeg --split pi_benchmark
python examples/conformal_eeg/test_tfm_tuab_inference.py --tokenizer-weights PATH --classifier-weights PATH
"""

import argparse
import os
import time

import torch

from pyhealth.datasets import (
TUABDataset,
get_dataloader,
split_by_patient_conformal_tuh,
split_by_patient_tuh,
)
from pyhealth.models import TFMTokenizer
from pyhealth.tasks import EEGAbnormalTUAB
from pyhealth.trainer import Trainer

# Repository root: three dirname() hops up from this file
# (examples/conformal_eeg/<script>.py -> repo root).
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Local checkpoint directory used by the conformal example scripts.
WEIGHTFILES = os.path.join(REPO_ROOT, "weightfiles")
# Default multi-dataset tokenizer checkpoint (setup 1 in the module docstring).
DEFAULT_TOKENIZER = os.path.join(WEIGHTFILES, "tfm_tokenizer_last.pth")
# Parent directory of the per-seed fine-tuned classifier folders
# (TFM_Tokenizer_multiple_finetuned_on_TUAB_1 .. _5, each with best_model.pth).
CLASSIFIER_WEIGHTS_DIR = os.path.join(
    WEIGHTFILES, "TFM_Tokenizer_multiple_finetuned_on_TUAB"
)
# PI's TUAB-specific checkpoints (setup 2 in the module docstring),
# selected via --pi-tuab-weights.
PI_TOKENIZER = os.path.join(WEIGHTFILES, "tfm_tokenizer_tuab.pth")
PI_CLASSIFIER = os.path.join(WEIGHTFILES, "tfm_encoder_best_model_tuab.pth")


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the TUAB inference sanity check."""
    parser = argparse.ArgumentParser(description="TFM TUAB inference sanity check")
    parser.add_argument(
        "--root",
        type=str,
        default="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf",
        help="Path to TUAB edf/ directory.",
    )
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        choices=[1, 2, 3, 4, 5],
        help="Which fine-tuned classifier folder _1.._5 (only if not using --classifier-weights).",
    )
    parser.add_argument(
        "--pi-tuab-weights",
        action="store_true",
        help="Use PI TUAB-specific files under weightfiles/: tfm_tokenizer_tuab.pth, "
        "tfm_encoder_best_model_tuab.pth",
    )
    parser.add_argument(
        "--tuab-pi-weights-dir",
        type=str,
        default=None,
        metavar="DIR",
        help="Directory containing PI's TUAB TFM files (e.g. /shared/eng/conformal_eeg). "
        "Loads tfm_tokenizer_tuab.pth + tfm_encoder_best_model_tuab.pth from there. "
        "Overrides --pi-tuab-weights and default weightfiles paths unless "
        "--tokenizer-weights / --classifier-weights are set explicitly.",
    )
    parser.add_argument(
        "--tokenizer-weights",
        type=str,
        default=None,
        help="Override tokenizer checkpoint path.",
    )
    parser.add_argument(
        "--classifier-weights",
        type=str,
        default=None,
        help="Override classifier checkpoint path (single .pth file).",
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["conformal", "pi_benchmark"],
        default="conformal",
        help="conformal: same as EEG conformal scripts; pi_benchmark: 0.875/0.125 train/val on train partition.",
    )
    parser.add_argument(
        "--split-seed",
        type=int,
        default=42,
        help="RNG seed for patient shuffle (pi_benchmark and conformal).",
    )
    return parser


def _resolve_weight_paths(args: argparse.Namespace) -> "tuple[str, str]":
    """Resolve (tokenizer_path, classifier_path) from CLI args.

    Precedence, lowest to highest:
      1. Defaults: multi-dataset tokenizer + per-seed fine-tuned classifier
         under weightfiles/ (selected by --seed).
      2. --pi-tuab-weights: PI's TUAB-specific files under weightfiles/.
      3. --tuab-pi-weights-dir DIR: the same two PI filenames loaded from DIR.
      4. Explicit --tokenizer-weights / --classifier-weights overrides,
         which always win (each independently).
    """
    if args.tuab_pi_weights_dir is not None:
        base_dir = os.path.expanduser(args.tuab_pi_weights_dir)
        tok = os.path.join(base_dir, "tfm_tokenizer_tuab.pth")
        cls_path = os.path.join(base_dir, "tfm_encoder_best_model_tuab.pth")
    elif args.pi_tuab_weights:
        tok = PI_TOKENIZER
        cls_path = PI_CLASSIFIER
    else:
        tok = DEFAULT_TOKENIZER
        cls_path = os.path.join(
            CLASSIFIER_WEIGHTS_DIR,
            f"TFM_Tokenizer_multiple_finetuned_on_TUAB_{args.seed}",
            "best_model.pth",
        )
    # Explicit per-file overrides trump every preset layout above.
    if args.tokenizer_weights is not None:
        tok = args.tokenizer_weights
    if args.classifier_weights is not None:
        cls_path = args.classifier_weights
    return tok, cls_path


def main():
    """Load TUAB, build the task dataset, split, and evaluate pretrained TFM weights.

    Pure inference: no training is performed. Prints timing for each stage
    and the final test metrics.
    """
    args = _build_parser().parse_args()
    device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu"
    tok, cls_path = _resolve_weight_paths(args)

    print(f"Device: {device}")
    print(f"TUAB root: {args.root}")
    print(f"Split mode: {args.split}")
    print(f"Tokenizer weights: {tok}")
    print(f"Classifier weights: {cls_path}")

    t0 = time.time()
    base_dataset = TUABDataset(root=args.root, subset="both")
    print(f"Dataset loaded in {time.time() - t0:.1f}s")

    # normalization="95th_percentile" + compute_stft must match how the
    # checkpoints were trained, otherwise metrics will look broken.
    t0 = time.time()
    sample_dataset = base_dataset.set_task(
        EEGAbnormalTUAB(
            resample_rate=200,
            normalization="95th_percentile",
            compute_stft=True,
        ),
        num_workers=16,
    )
    print(f"Task set in {time.time() - t0:.1f}s | total samples: {len(sample_dataset)}")

    # Both split modes end with the TUH eval partition as the fixed test set;
    # they differ only in how the train partition is divided.
    if args.split == "conformal":
        _, _, _, test_ds = split_by_patient_conformal_tuh(
            dataset=sample_dataset,
            ratios=[0.6, 0.2, 0.2],
            seed=args.split_seed,
        )
    else:
        _, _, test_ds = split_by_patient_tuh(
            sample_dataset,
            [0.875, 0.125],
            seed=args.split_seed,
        )

    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
    print(f"Test set size: {len(test_ds)}")

    model = TFMTokenizer(dataset=sample_dataset).to(device)
    model.load_pretrained_weights(
        tokenizer_checkpoint_path=tok,
        classifier_checkpoint_path=cls_path,
    )

    trainer = Trainer(
        model=model,
        device=device,
        metrics=[
            "accuracy",
            "balanced_accuracy",
            "f1_weighted",
            "f1_macro",
            "roc_auc_weighted_ovr",
        ],
        enable_logging=False,
    )
    t0 = time.time()
    results = trainer.evaluate(test_loader)
    print(f"\nEval time: {time.time() - t0:.1f}s")
    print("\n=== Test Results ===")
    for metric, value in results.items():
        print(f"  {metric}: {value:.4f}")


# Script entry point: run the inference sanity check when executed directly.
if __name__ == "__main__":
    main()
112 changes: 112 additions & 0 deletions examples/conformal_eeg/test_tfm_tuev_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Quick inference test: TFMTokenizer on TUEV using local weightfiles/.

Mirrors the PI's benchmark script but uses the weightfiles/ paths already
present in this repo. No training — pure inference to verify weights and
normalization are correct.

Usage:
python examples/conformal_eeg/test_tfm_tuev_inference.py
python examples/conformal_eeg/test_tfm_tuev_inference.py --gpu_id 1
python examples/conformal_eeg/test_tfm_tuev_inference.py --seed 2 # use _2/best_model.pth
"""

import argparse
import os
import time

import torch

from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_patient_conformal_tuh
from pyhealth.models import TFMTokenizer
from pyhealth.tasks import EEGEventsTUEV
from pyhealth.trainer import Trainer

# Default location of the TUEV edf/ directory on the lab server.
TUEV_ROOT = "/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/"

# Repository root: three dirname() hops up from this file
# (examples/conformal_eeg/<script>.py -> repo root).
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Pre-trained multi-dataset tokenizer checkpoint.
TOKENIZER_WEIGHTS = os.path.join(REPO_ROOT, "weightfiles", "tfm_tokenizer_last.pth")
# Parent directory of the per-seed fine-tuned classifier folders
# (TFM_Tokenizer_multiple_finetuned_on_TUEV_1 .. _5, each with best_model.pth).
CLASSIFIER_WEIGHTS_DIR = os.path.join(
    REPO_ROOT, "weightfiles", "TFM_Tokenizer_multiple_finetuned_on_TUEV"
)


def main():
    """Run a pure-inference sanity check of TFMTokenizer on TUEV.

    Loads the dataset, builds the event-classification task, extracts the
    fixed TUH eval test set, loads pretrained tokenizer + classifier
    weights, and prints evaluation metrics. No training is performed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu_id", type=int, default=0)
    # Added for consistency with test_tfm_tuab_inference.py: allow running
    # against a non-default data location without editing the source.
    parser.add_argument(
        "--root",
        type=str,
        default=TUEV_ROOT,
        help="Path to TUEV edf/ directory.",
    )
    parser.add_argument(
        "--seed", type=int, default=1, choices=[1, 2, 3, 4, 5],
        help="Which fine-tuned classifier to use (1-5)."
    )
    args = parser.parse_args()
    device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu"

    classifier_weights = os.path.join(
        CLASSIFIER_WEIGHTS_DIR,
        f"TFM_Tokenizer_multiple_finetuned_on_TUEV_{args.seed}",
        "best_model.pth",
    )

    print(f"Device: {device}")
    print(f"TUEV root: {args.root}")
    print(f"Tokenizer weights: {TOKENIZER_WEIGHTS}")
    print(f"Classifier weights: {classifier_weights}")

    # ------------------------------------------------------------------ #
    # STEP 1: Load dataset
    # ------------------------------------------------------------------ #
    t0 = time.time()
    base_dataset = TUEVDataset(root=args.root, subset="both")
    print(f"Dataset loaded in {time.time() - t0:.1f}s")

    # ------------------------------------------------------------------ #
    # STEP 2: Set task — normalization="95th_percentile" matches training
    # ------------------------------------------------------------------ #
    t0 = time.time()
    sample_dataset = base_dataset.set_task(
        EEGEventsTUEV(
            resample_rate=200,
            normalization="95th_percentile",
            compute_stft=True,
        )
    )
    print(f"Task set in {time.time() - t0:.1f}s | total samples: {len(sample_dataset)}")

    # ------------------------------------------------------------------ #
    # STEP 3: Extract fixed test set (TUH eval partition)
    # ------------------------------------------------------------------ #
    _, _, _, test_ds = split_by_patient_conformal_tuh(
        dataset=sample_dataset,
        ratios=[0.6, 0.2, 0.2],
        seed=42,
    )
    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
    print(f"Test set size: {len(test_ds)}")

    # ------------------------------------------------------------------ #
    # STEP 4: Load TFMTokenizer with pre-trained weights (no training)
    # ------------------------------------------------------------------ #
    model = TFMTokenizer(dataset=sample_dataset).to(device)
    model.load_pretrained_weights(
        tokenizer_checkpoint_path=TOKENIZER_WEIGHTS,
        classifier_checkpoint_path=classifier_weights,
    )

    # ------------------------------------------------------------------ #
    # STEP 5: Evaluate
    # ------------------------------------------------------------------ #
    trainer = Trainer(
        model=model,
        device=device,
        metrics=["accuracy", "f1_weighted", "f1_macro"],
        enable_logging=False,
    )
    t0 = time.time()
    results = trainer.evaluate(test_loader)
    print(f"\nEval time: {time.time() - t0:.1f}s")
    print("\n=== Test Results ===")
    for metric, value in results.items():
        print(f"  {metric}: {value:.4f}")


# Script entry point: run the inference sanity check when executed directly.
if __name__ == "__main__":
    main()
32 changes: 23 additions & 9 deletions examples/conformal_eeg/tuab_conventional_conformal.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,16 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--weights-dir",
type=str,
default="weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB",
help="Root folder of fine-tuned TFM classifier checkpoints (only with --model tfm).",
default="/shared/eng/conformal_eeg",
help="Root folder of TFM classifier checkpoints (only with --model tfm). "
"If the directory contains tfm_encoder_best_model_tuab.pth directly, "
"that single checkpoint is used for all seeds (PI TUAB setup). "
"Otherwise expects per-seed subdirs {base}_1..N/best_model.pth.",
)
parser.add_argument(
"--tokenizer-weights",
type=str,
default="weightfiles/tfm_tokenizer_last.pth",
default="/shared/eng/conformal_eeg/tfm_tokenizer_tuab.pth",
help="Path to the pre-trained TFM tokenizer weights (only with --model tfm).",
)
parser.add_argument(
Expand All @@ -152,9 +155,19 @@ def _do_split(dataset, ratios, seed, split_type):


def _load_tfm_weights(model, args, run_idx: int) -> None:
"""Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based)."""
base = os.path.basename(args.weights_dir)
classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth")
"""Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based).

Supports two layouts:
- Single classifier (PI TUAB setup): weights_dir/tfm_encoder_best_model_tuab.pth
Used for all seeds — only the data split varies across runs.
- Per-seed subdirs: weights_dir/{base}_{run_idx+1}/best_model.pth
"""
single = os.path.join(args.weights_dir, "tfm_encoder_best_model_tuab.pth")
if os.path.isfile(single):
classifier_path = single
else:
base = os.path.basename(args.weights_dir)
classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth")
print(f" Loading TFM weights (run {run_idx + 1}): {classifier_path}")
model.load_pretrained_weights(
tokenizer_checkpoint_path=args.tokenizer_weights,
Expand Down Expand Up @@ -283,7 +296,7 @@ def _print_multi_seed_summary(
n_runs = len(all_metrics)

print("\n" + "=" * 80)
print("Per-run LABEL results (fixed test set = TUH eval partition)")
print(f"Per-run results — alpha={alpha} (LABEL, fixed test set = TUH eval partition)")
print("=" * 80)
print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'ROC-AUC':<10} {'F1':<8} "
f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}")
Expand All @@ -295,7 +308,8 @@ def _print_multi_seed_summary(
f"{m['miscoverage']:<12.4f} {m['avg_set_size']:<12.2f}")

print("\n" + "=" * 80)
print(f"LABEL summary (mean \u00b1 std over {n_runs} runs, fixed test set)")
print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)")
print(" Method: LABEL")
print("=" * 80)
print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}")
print(f" ROC-AUC: {roc_aucs.mean():.4f} \u00b1 {roc_aucs.std():.4f}")
Expand Down Expand Up @@ -334,7 +348,7 @@ def _main(args: argparse.Namespace) -> None:
print("STEP 1: Load TUAB + build task dataset (shared across all seeds)")
print("=" * 80)
dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test)
sample_dataset = dataset.set_task(EEGAbnormalTUAB())
sample_dataset = dataset.set_task(EEGAbnormalTUAB(normalization="95th_percentile"), num_workers=16)
if args.quick_test and len(sample_dataset) > quick_test_max_samples:
sample_dataset = sample_dataset.subset(range(quick_test_max_samples))
print(f"Capped to {quick_test_max_samples} samples for quick-test.")
Expand Down
Loading
Loading