diff --git a/examples/conformal_eeg/test_tfm_tuab_inference.py b/examples/conformal_eeg/test_tfm_tuab_inference.py new file mode 100644 index 000000000..400518469 --- /dev/null +++ b/examples/conformal_eeg/test_tfm_tuab_inference.py @@ -0,0 +1,197 @@ +""" +Quick inference test: TFMTokenizer on TUAB using local weightfiles/. + +Two weight setups (ask your PI which matches their training): + + 1) Default (matches conformal example scripts): + - tokenizer: weightfiles/tfm_tokenizer_last.pth (multi-dataset tokenizer) + - classifier: weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB/.../best_model.pth + + 2) PI benchmark TUAB-specific files (place in weightfiles/): + - tokenizer: tfm_tokenizer_tuab.pth + - classifier: tfm_encoder_best_model_tuab.pth + Use: --pi-tuab-weights + +Split modes: + - conformal (default): same test set as conformal runs (TUH eval via patient conformal split). + - pi_benchmark: train/val ratio [0.875, 0.125] on train partition; test = TUH eval (same patients as official eval). + +Usage: + python examples/conformal_eeg/test_tfm_tuab_inference.py + python examples/conformal_eeg/test_tfm_tuab_inference.py --pi-tuab-weights + python examples/conformal_eeg/test_tfm_tuab_inference.py \\ + --tuab-pi-weights-dir /shared/eng/conformal_eeg --split pi_benchmark + python examples/conformal_eeg/test_tfm_tuab_inference.py --tokenizer-weights PATH --classifier-weights PATH +""" + +import argparse +import os +import time + +import torch + +from pyhealth.datasets import ( + TUABDataset, + get_dataloader, + split_by_patient_conformal_tuh, + split_by_patient_tuh, +) +from pyhealth.models import TFMTokenizer +from pyhealth.tasks import EEGAbnormalTUAB +from pyhealth.trainer import Trainer + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +WEIGHTFILES = os.path.join(REPO_ROOT, "weightfiles") +DEFAULT_TOKENIZER = os.path.join(WEIGHTFILES, "tfm_tokenizer_last.pth") +CLASSIFIER_WEIGHTS_DIR = os.path.join( + WEIGHTFILES, "TFM_Tokenizer_multiple_finetuned_on_TUAB" +) +PI_TOKENIZER = os.path.join(WEIGHTFILES, "tfm_tokenizer_tuab.pth") +PI_CLASSIFIER = os.path.join(WEIGHTFILES, "tfm_encoder_best_model_tuab.pth") + + +def main(): + parser = argparse.ArgumentParser(description="TFM TUAB inference sanity check") + parser.add_argument( + "--root", + type=str, + default="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf", + help="Path to TUAB edf/ directory.", + ) + parser.add_argument("--gpu_id", type=int, default=0) + parser.add_argument( + "--seed", + type=int, + default=1, + choices=[1, 2, 3, 4, 5], + help="Which fine-tuned classifier folder _1.._5 (only if not using --classifier-weights).", + ) + parser.add_argument( + "--pi-tuab-weights", + action="store_true", + help="Use PI TUAB-specific files under weightfiles/: tfm_tokenizer_tuab.pth, " + "tfm_encoder_best_model_tuab.pth", + ) + parser.add_argument( + "--tuab-pi-weights-dir", + type=str, + default=None, + metavar="DIR", + help="Directory containing PI's TUAB TFM files (e.g. /shared/eng/conformal_eeg). " + "Loads tfm_tokenizer_tuab.pth + tfm_encoder_best_model_tuab.pth from there. " + "Overrides --pi-tuab-weights and default weightfiles paths unless " + "--tokenizer-weights / --classifier-weights are set explicitly.", + ) + parser.add_argument( + "--tokenizer-weights", + type=str, + default=None, + help="Override tokenizer checkpoint path.", + ) + parser.add_argument( + "--classifier-weights", + type=str, + default=None, + help="Override classifier checkpoint path (single .pth file).", + ) + parser.add_argument( + "--split", + type=str, + choices=["conformal", "pi_benchmark"], + default="conformal", + help="conformal: same as EEG conformal scripts; pi_benchmark: 0.875/0.125 train/val on train partition.", + ) + parser.add_argument( + "--split-seed", + type=int, + default=42, + help="RNG seed for patient shuffle (pi_benchmark and conformal).", + ) + args = parser.parse_args() + device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu" + + if args.tuab_pi_weights_dir is not None: + d = os.path.expanduser(args.tuab_pi_weights_dir) + tok = os.path.join(d, "tfm_tokenizer_tuab.pth") + cls_path = os.path.join(d, "tfm_encoder_best_model_tuab.pth") + elif args.pi_tuab_weights: + tok = PI_TOKENIZER + cls_path = PI_CLASSIFIER + else: + tok = DEFAULT_TOKENIZER + cls_path = os.path.join( + CLASSIFIER_WEIGHTS_DIR, + f"TFM_Tokenizer_multiple_finetuned_on_TUAB_{args.seed}", + "best_model.pth", + ) + + if args.tokenizer_weights is not None: + tok = args.tokenizer_weights + if args.classifier_weights is not None: + cls_path = args.classifier_weights + + print(f"Device: {device}") + print(f"TUAB root: {args.root}") + print(f"Split mode: {args.split}") + print(f"Tokenizer weights: {tok}") + print(f"Classifier weights: {cls_path}") + + t0 = time.time() + base_dataset = TUABDataset(root=args.root, subset="both") + print(f"Dataset loaded in {time.time() - t0:.1f}s") + + t0 = time.time() + sample_dataset = base_dataset.set_task( + EEGAbnormalTUAB( + resample_rate=200, + normalization="95th_percentile", + compute_stft=True, + ), + num_workers=16, + ) + print(f"Task set in {time.time() - t0:.1f}s | total samples: {len(sample_dataset)}") + + if args.split == "conformal": + _, _, _, test_ds = split_by_patient_conformal_tuh( + dataset=sample_dataset, + ratios=[0.6, 0.2, 0.2], + seed=args.split_seed, + ) + else: + _, _, test_ds = split_by_patient_tuh( + sample_dataset, + [0.875, 0.125], + seed=args.split_seed, + ) + + test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False) + print(f"Test set size: {len(test_ds)}") + + model = TFMTokenizer(dataset=sample_dataset).to(device) + model.load_pretrained_weights( + tokenizer_checkpoint_path=tok, + classifier_checkpoint_path=cls_path, + ) + + trainer = Trainer( + model=model, + device=device, + metrics=[ + "accuracy", + "balanced_accuracy", + "f1_weighted", + "f1_macro", + "roc_auc_weighted_ovr", + ], + enable_logging=False, + ) + t0 = time.time() + results = trainer.evaluate(test_loader) + print(f"\nEval time: {time.time() - t0:.1f}s") + print("\n=== Test Results ===") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + +if __name__ == "__main__": + main() diff --git a/examples/conformal_eeg/test_tfm_tuev_inference.py b/examples/conformal_eeg/test_tfm_tuev_inference.py new file mode 100644 index 000000000..fede48d4d --- /dev/null +++ b/examples/conformal_eeg/test_tfm_tuev_inference.py @@ -0,0 +1,112 @@ +""" +Quick inference test: TFMTokenizer on TUEV using local weightfiles/. + +Mirrors the PI's benchmark script but uses the weightfiles/ paths already +present in this repo. No training — pure inference to verify weights and +normalization are correct. + +Usage: + python examples/conformal_eeg/test_tfm_tuev_inference.py + python examples/conformal_eeg/test_tfm_tuev_inference.py --gpu_id 1 + python examples/conformal_eeg/test_tfm_tuev_inference.py --seed 2 # use _2/best_model.pth +""" + +import argparse +import os +import time + +import torch + +from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_patient_conformal_tuh +from pyhealth.models import TFMTokenizer +from pyhealth.tasks import EEGEventsTUEV +from pyhealth.trainer import Trainer + +TUEV_ROOT = "/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/" + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +TOKENIZER_WEIGHTS = os.path.join(REPO_ROOT, "weightfiles", "tfm_tokenizer_last.pth") +CLASSIFIER_WEIGHTS_DIR = os.path.join( + REPO_ROOT, "weightfiles", "TFM_Tokenizer_multiple_finetuned_on_TUEV" +) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--gpu_id", type=int, default=0) + parser.add_argument( + "--seed", type=int, default=1, choices=[1, 2, 3, 4, 5], + help="Which fine-tuned classifier to use (1-5)." + ) + args = parser.parse_args() + device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu" + + classifier_weights = os.path.join( + CLASSIFIER_WEIGHTS_DIR, + f"TFM_Tokenizer_multiple_finetuned_on_TUEV_{args.seed}", + "best_model.pth", + ) + + print(f"Device: {device}") + print(f"Tokenizer weights: {TOKENIZER_WEIGHTS}") + print(f"Classifier weights: {classifier_weights}") + + # ------------------------------------------------------------------ # + # STEP 1: Load dataset + # ------------------------------------------------------------------ # + t0 = time.time() + base_dataset = TUEVDataset(root=TUEV_ROOT, subset="both") + print(f"Dataset loaded in {time.time() - t0:.1f}s") + + # ------------------------------------------------------------------ # + # STEP 2: Set task — normalization="95th_percentile" matches training + # ------------------------------------------------------------------ # + t0 = time.time() + sample_dataset = base_dataset.set_task( + EEGEventsTUEV( + resample_rate=200, + normalization="95th_percentile", + compute_stft=True, + ) + ) + print(f"Task set in {time.time() - t0:.1f}s | total samples: {len(sample_dataset)}") + + # ------------------------------------------------------------------ # + # STEP 3: Extract fixed test set (TUH eval partition) + # ------------------------------------------------------------------ # + _, _, _, test_ds = split_by_patient_conformal_tuh( + dataset=sample_dataset, + ratios=[0.6, 0.2, 0.2], + seed=42, + ) + test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False) + print(f"Test set size: {len(test_ds)}") + + # ------------------------------------------------------------------ # + # STEP 4: Load TFMTokenizer with pre-trained weights (no training) + # ------------------------------------------------------------------ # + model = TFMTokenizer(dataset=sample_dataset).to(device) + model.load_pretrained_weights( + tokenizer_checkpoint_path=TOKENIZER_WEIGHTS, + classifier_checkpoint_path=classifier_weights, + ) + + # ------------------------------------------------------------------ # + # STEP 5: Evaluate + # ------------------------------------------------------------------ # + trainer = Trainer( + model=model, + device=device, + metrics=["accuracy", "f1_weighted", "f1_macro"], + enable_logging=False, + ) + t0 = time.time() + results = trainer.evaluate(test_loader) + print(f"\nEval time: {time.time() - t0:.1f}s") + print("\n=== Test Results ===") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + +if __name__ == "__main__": + main() diff --git a/examples/conformal_eeg/tuab_conventional_conformal.py b/examples/conformal_eeg/tuab_conventional_conformal.py index 0f6b32401..7789b83e7 100644 --- a/examples/conformal_eeg/tuab_conventional_conformal.py +++ b/examples/conformal_eeg/tuab_conventional_conformal.py @@ -123,13 +123,16 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--weights-dir", type=str, - default="weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB", - help="Root folder of fine-tuned TFM classifier checkpoints (only with --model tfm).", + default="/shared/eng/conformal_eeg", + help="Root folder of TFM classifier checkpoints (only with --model tfm). " + "If the directory contains tfm_encoder_best_model_tuab.pth directly, " + "that single checkpoint is used for all seeds (PI TUAB setup). " + "Otherwise expects per-seed subdirs {base}_1..N/best_model.pth.", ) parser.add_argument( "--tokenizer-weights", type=str, - default="weightfiles/tfm_tokenizer_last.pth", + default="/shared/eng/conformal_eeg/tfm_tokenizer_tuab.pth", help="Path to the pre-trained TFM tokenizer weights (only with --model tfm).", ) parser.add_argument( @@ -152,9 +155,19 @@ def _do_split(dataset, ratios, seed, split_type): def _load_tfm_weights(model, args, run_idx: int) -> None: - """Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based).""" - base = os.path.basename(args.weights_dir) - classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth") + """Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based). + + Supports two layouts: + - Single classifier (PI TUAB setup): weights_dir/tfm_encoder_best_model_tuab.pth + Used for all seeds — only the data split varies across runs. + - Per-seed subdirs: weights_dir/{base}_{run_idx+1}/best_model.pth + """ + single = os.path.join(args.weights_dir, "tfm_encoder_best_model_tuab.pth") + if os.path.isfile(single): + classifier_path = single + else: + base = os.path.basename(args.weights_dir) + classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth") print(f" Loading TFM weights (run {run_idx + 1}): {classifier_path}") model.load_pretrained_weights( tokenizer_checkpoint_path=args.tokenizer_weights, @@ -283,7 +296,7 @@ def _print_multi_seed_summary( n_runs = len(all_metrics) print("\n" + "=" * 80) - print("Per-run LABEL results (fixed test set = TUH eval partition)") + print(f"Per-run results — alpha={alpha} (LABEL, fixed test set = TUH eval partition)") print("=" * 80) print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'ROC-AUC':<10} {'F1':<8} " f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") @@ -295,7 +308,8 @@ def _print_multi_seed_summary( f"{m['miscoverage']:<12.4f} {m['avg_set_size']:<12.2f}") print("\n" + "=" * 80) - print(f"LABEL summary (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(" Method: LABEL") print("=" * 80) print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}") print(f" ROC-AUC: {roc_aucs.mean():.4f} \u00b1 {roc_aucs.std():.4f}") @@ -334,7 +348,7 @@ def _main(args: argparse.Namespace) -> None: print("STEP 1: Load TUAB + build task dataset (shared across all seeds)") print("=" * 80) dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGAbnormalTUAB()) + sample_dataset = dataset.set_task(EEGAbnormalTUAB(normalization="95th_percentile"), num_workers=16) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/conformal_eeg/tuab_covariate_shift_conformal.py b/examples/conformal_eeg/tuab_covariate_shift_conformal.py index 33a810ab1..11460d139 100644 --- a/examples/conformal_eeg/tuab_covariate_shift_conformal.py +++ b/examples/conformal_eeg/tuab_covariate_shift_conformal.py @@ -129,13 +129,16 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--weights-dir", type=str, - default="weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB", - help="Root folder of fine-tuned TFM classifier checkpoints (only with --model tfm).", + default="/shared/eng/conformal_eeg", + help="Root folder of TFM classifier checkpoints (only with --model tfm). " + "If the directory contains tfm_encoder_best_model_tuab.pth directly, " + "that single checkpoint is used for all seeds (PI TUAB setup). " + "Otherwise expects per-seed subdirs {base}_1..N/best_model.pth.", ) parser.add_argument( "--tokenizer-weights", type=str, - default="weightfiles/tfm_tokenizer_last.pth", + default="/shared/eng/conformal_eeg/tfm_tokenizer_tuab.pth", help="Path to the pre-trained TFM tokenizer weights (only with --model tfm).", ) parser.add_argument( @@ -158,9 +161,19 @@ def _do_split(dataset, ratios, seed, split_type): def _load_tfm_weights(model, args, run_idx: int) -> None: - """Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based).""" - base = os.path.basename(args.weights_dir) - classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth") + """Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based). + + Supports two layouts: + - Single classifier (PI TUAB setup): weights_dir/tfm_encoder_best_model_tuab.pth + Used for all seeds — only the data split varies across runs. + - Per-seed subdirs: weights_dir/{base}_{run_idx+1}/best_model.pth + """ + single = os.path.join(args.weights_dir, "tfm_encoder_best_model_tuab.pth") + if os.path.isfile(single): + classifier_path = single + else: + base = os.path.basename(args.weights_dir) + classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth") print(f" Loading TFM weights (run {run_idx + 1}): {classifier_path}") model.load_pretrained_weights( tokenizer_checkpoint_path=args.tokenizer_weights, @@ -299,7 +312,7 @@ def _print_multi_seed_summary( n_runs = len(all_metrics) print("\n" + "=" * 80) - print("Per-run CovariateLabel results (fixed test set = TUH eval partition)") + print(f"Per-run results — alpha={alpha} (CovariateLabel, fixed test set = TUH eval partition)") print("=" * 80) print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'ROC-AUC':<10} {'F1':<8} " f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") @@ -311,7 +324,8 @@ def _print_multi_seed_summary( f"{m['miscoverage']:<12.4f} {m['avg_set_size']:<12.2f}") print("\n" + "=" * 80) - print(f"CovariateLabel summary (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(" Method: CovariateLabel") print("=" * 80) print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}") print(f" ROC-AUC: {roc_aucs.mean():.4f} \u00b1 {roc_aucs.std():.4f}") @@ -350,7 +364,7 @@ def _main(args: argparse.Namespace) -> None: print("STEP 1: Load TUAB + build task dataset (shared across all seeds)") print("=" * 80) dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGAbnormalTUAB()) + sample_dataset = dataset.set_task(EEGAbnormalTUAB(normalization="95th_percentile"), num_workers=16) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/conformal_eeg/tuab_kmeans_conformal.py b/examples/conformal_eeg/tuab_kmeans_conformal.py index 67152bfff..99b9742f4 100644 --- a/examples/conformal_eeg/tuab_kmeans_conformal.py +++ b/examples/conformal_eeg/tuab_kmeans_conformal.py @@ -132,13 +132,16 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--weights-dir", type=str, - default="weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB", - help="Root folder of fine-tuned TFM classifier checkpoints (only with --model tfm).", + default="/shared/eng/conformal_eeg", + help="Root folder of TFM classifier checkpoints (only with --model tfm). " + "If the directory contains tfm_encoder_best_model_tuab.pth directly, " + "that single checkpoint is used for all seeds (PI TUAB setup). " + "Otherwise expects per-seed subdirs {base}_1..N/best_model.pth.", ) parser.add_argument( "--tokenizer-weights", type=str, - default="weightfiles/tfm_tokenizer_last.pth", + default="/shared/eng/conformal_eeg/tfm_tokenizer_tuab.pth", help="Path to the pre-trained TFM tokenizer weights (only with --model tfm).", ) parser.add_argument( @@ -161,9 +164,19 @@ def _do_split(dataset, ratios, seed, split_type): def _load_tfm_weights(model, args, run_idx: int) -> None: - """Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based).""" - base = os.path.basename(args.weights_dir) - classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth") + """Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based). + + Supports two layouts: + - Single classifier (PI TUAB setup): weights_dir/tfm_encoder_best_model_tuab.pth + Used for all seeds — only the data split varies across runs. + - Per-seed subdirs: weights_dir/{base}_{run_idx+1}/best_model.pth + """ + single = os.path.join(args.weights_dir, "tfm_encoder_best_model_tuab.pth") + if os.path.isfile(single): + classifier_path = single + else: + base = os.path.basename(args.weights_dir) + classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth") print(f" Loading TFM weights (run {run_idx + 1}): {classifier_path}") model.load_pretrained_weights( tokenizer_checkpoint_path=args.tokenizer_weights, @@ -308,7 +321,7 @@ def _print_multi_seed_summary( n_runs = len(all_metrics) print("\n" + "=" * 80) - print("Per-run ClusterLabel results (fixed test set = TUH eval partition)") + print(f"Per-run results — alpha={alpha} (ClusterLabel, fixed test set = TUH eval partition)") print("=" * 80) print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'ROC-AUC':<10} {'F1':<8} " f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") @@ -320,7 +333,8 @@ def _print_multi_seed_summary( f"{m['miscoverage']:<12.4f} {m['avg_set_size']:<12.2f}") print("\n" + "=" * 80) - print(f"ClusterLabel summary (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(" Method: ClusterLabel") print("=" * 80) print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}") print(f" ROC-AUC: {roc_aucs.mean():.4f} \u00b1 {roc_aucs.std():.4f}") @@ -360,7 +374,7 @@ def _main(args: argparse.Namespace) -> None: print("STEP 1: Load TUAB + build task dataset (shared across all seeds)") print("=" * 80) dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGAbnormalTUAB()) + sample_dataset = dataset.set_task(EEGAbnormalTUAB(normalization="95th_percentile"), num_workers=16) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/conformal_eeg/tuab_ncp_conformal.py b/examples/conformal_eeg/tuab_ncp_conformal.py index be658d794..90166686a 100644 --- a/examples/conformal_eeg/tuab_ncp_conformal.py +++ b/examples/conformal_eeg/tuab_ncp_conformal.py @@ -136,13 +136,16 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--weights-dir", type=str, - default="weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB", - help="Root folder of fine-tuned TFM classifier checkpoints (only with --model tfm).", + default="/shared/eng/conformal_eeg", + help="Root folder of TFM classifier checkpoints (only with --model tfm). " + "If the directory contains tfm_encoder_best_model_tuab.pth directly, " + "that single checkpoint is used for all seeds (PI TUAB setup). " + "Otherwise expects per-seed subdirs {base}_1..N/best_model.pth.", ) parser.add_argument( "--tokenizer-weights", type=str, - default="weightfiles/tfm_tokenizer_last.pth", + default="/shared/eng/conformal_eeg/tfm_tokenizer_tuab.pth", help="Path to the pre-trained TFM tokenizer weights (only with --model tfm).", ) parser.add_argument( @@ -165,9 +168,19 @@ def _do_split(dataset, ratios, seed, split_type): def _load_tfm_weights(model, args, run_idx: int) -> None: - """Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based).""" - base = os.path.basename(args.weights_dir) - classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth") + """Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based). + + Supports two layouts: + - Single classifier (PI TUAB setup): weights_dir/tfm_encoder_best_model_tuab.pth + Used for all seeds — only the data split varies across runs. + - Per-seed subdirs: weights_dir/{base}_{run_idx+1}/best_model.pth + """ + single = os.path.join(args.weights_dir, "tfm_encoder_best_model_tuab.pth") + if os.path.isfile(single): + classifier_path = single + else: + base = os.path.basename(args.weights_dir) + classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth") print(f" Loading TFM weights (run {run_idx + 1}): {classifier_path}") model.load_pretrained_weights( tokenizer_checkpoint_path=args.tokenizer_weights, @@ -319,7 +332,7 @@ def _run(args: argparse.Namespace) -> None: print("STEP 1: Load TUAB + build task dataset") print("=" * 80) dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGAbnormalTUAB()) + sample_dataset = dataset.set_task(EEGAbnormalTUAB(normalization="95th_percentile"), num_workers=16) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") @@ -410,7 +423,10 @@ def _run(args: argparse.Namespace) -> None: set_sizes = np.array([m["avg_set_size"] for m in mlist]) if not use_multi_seed: - print(f"\nNCP Results (alpha={alpha}):") + print("\n" + "=" * 80) + print(f"Summary — alpha={alpha} (single run, fixed test set)") + print(" Method: NeighborhoodLabel") + print("=" * 80) print(f" Accuracy: {accs[0]:.4f}") print(f" ROC-AUC: {roc_aucs[0]:.4f}") print(f" F1: {f1s[0]:.4f}") @@ -421,7 +437,7 @@ def _run(args: argparse.Namespace) -> None: print(f" k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") else: print("\n" + "=" * 80) - print(f"Per-run NCP results — alpha={alpha} (target coverage={1-alpha:.0%})") + print(f"Per-run results — alpha={alpha} (NeighborhoodLabel, fixed test set = TUH eval partition)") print("=" * 80) print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'ROC-AUC':<10} {'F1':<8} " f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") @@ -431,14 +447,15 @@ def _run(args: argparse.Namespace) -> None: f"{f1s[i]:<8.4f} {coverages[i]:<10.4f} {miscovs[i]:<12.4f} {set_sizes[i]:<12.2f}") print("\n" + "=" * 80) - print(f"NCP summary — alpha={alpha} (mean ± std over {n_runs} runs, fixed test set)") + print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(" Method: NeighborhoodLabel") print("=" * 80) - print(f" Accuracy: {accs.mean():.4f} ± {accs.std():.4f}") - print(f" ROC-AUC: {roc_aucs.mean():.4f} ± {roc_aucs.std():.4f}") - print(f" F1: {f1s.mean():.4f} ± {f1s.std():.4f}") - print(f" Empirical coverage: {coverages.mean():.4f} ± {coverages.std():.4f}") - print(f" Empirical miscoverage: {miscovs.mean():.4f} ± {miscovs.std():.4f}") - print(f" Average set size: {set_sizes.mean():.2f} ± {set_sizes.std():.2f}") + print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}") + print(f" ROC-AUC: {roc_aucs.mean():.4f} \u00b1 {roc_aucs.std():.4f}") + print(f" F1: {f1s.mean():.4f} \u00b1 {f1s.std():.4f}") + print(f" Empirical coverage: {coverages.mean():.4f} \u00b1 {coverages.std():.4f}") + print(f" Empirical miscoverage: {miscovs.mean():.4f} \u00b1 {miscovs.std():.4f}") + print(f" Average set size: {set_sizes.mean():.2f} \u00b1 {set_sizes.std():.2f}") print(f" Target coverage: {1 - alpha:.0%} (alpha={alpha})") print(f" k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") print(f" Test set size: {n_test} (fixed across runs)") diff --git a/examples/conformal_eeg/tuev_conventional_conformal.py b/examples/conformal_eeg/tuev_conventional_conformal.py index fe41aaff4..e5542c7d4 100644 --- a/examples/conformal_eeg/tuev_conventional_conformal.py +++ b/examples/conformal_eeg/tuev_conventional_conformal.py @@ -283,7 +283,7 @@ def _print_multi_seed_summary( n_runs = len(all_metrics) print("\n" + "=" * 80) - print("Per-run LABEL results (fixed test set = TUH eval partition)") + print(f"Per-run results — alpha={alpha} (LABEL, fixed test set = TUH eval partition)") print("=" * 80) print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'F1-Wt':<10} " f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") @@ -295,7 +295,8 @@ def _print_multi_seed_summary( f"{m['miscoverage']:<12.4f} {m['avg_set_size']:<12.2f}") print("\n" + "=" * 80) - print(f"LABEL summary (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(" Method: LABEL") print("=" * 80) print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}") print(f" F1 (weighted): {f1s.mean():.4f} \u00b1 {f1s.std():.4f}") @@ -332,7 +333,7 @@ def _main(args: argparse.Namespace) -> None: print("STEP 1: Load TUEV + build task dataset (shared across all seeds)") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV()) + sample_dataset = dataset.set_task(EEGEventsTUEV(normalization="95th_percentile"), num_workers=16) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/conformal_eeg/tuev_covariate_shift_conformal.py b/examples/conformal_eeg/tuev_covariate_shift_conformal.py index 97169a0d7..d1bcba20a 100644 --- a/examples/conformal_eeg/tuev_covariate_shift_conformal.py +++ b/examples/conformal_eeg/tuev_covariate_shift_conformal.py @@ -296,7 +296,7 @@ def _print_multi_seed_summary( n_runs = len(all_metrics) print("\n" + "=" * 80) - print("Per-run CovariateLabel results (fixed test set = TUH eval partition)") + print(f"Per-run results — alpha={alpha} (CovariateLabel, fixed test set = TUH eval partition)") print("=" * 80) print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'F1-Wt':<10} " f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") @@ -308,7 +308,8 @@ def _print_multi_seed_summary( f"{m['miscoverage']:<12.4f} {m['avg_set_size']:<12.2f}") print("\n" + "=" * 80) - print(f"CovariateLabel summary (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(" Method: CovariateLabel") print("=" * 80) print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}") print(f" F1 (weighted): {f1s.mean():.4f} \u00b1 {f1s.std():.4f}") @@ -345,7 +346,7 @@ def _main(args: argparse.Namespace) -> None: print("STEP 1: Load TUEV + build task dataset (shared across all seeds)") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV()) + sample_dataset = dataset.set_task(EEGEventsTUEV(normalization="95th_percentile"), num_workers=16) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/conformal_eeg/tuev_kmeans_conformal.py b/examples/conformal_eeg/tuev_kmeans_conformal.py index 598ad43b1..faad50eaa 100644 --- a/examples/conformal_eeg/tuev_kmeans_conformal.py +++ b/examples/conformal_eeg/tuev_kmeans_conformal.py @@ -305,7 +305,7 @@ def _print_multi_seed_summary( n_runs = len(all_metrics) print("\n" + "=" * 80) - print("Per-run ClusterLabel results (fixed test set = TUH eval partition)") + print(f"Per-run results — alpha={alpha} (ClusterLabel, fixed test set = TUH eval partition)") print("=" * 80) print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'F1-Wt':<10} " f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") @@ -317,7 +317,8 @@ def _print_multi_seed_summary( f"{m['miscoverage']:<12.4f} {m['avg_set_size']:<12.2f}") print("\n" + "=" * 80) - print(f"ClusterLabel summary (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(" Method: ClusterLabel") print("=" * 80) print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}") print(f" F1 (weighted): {f1s.mean():.4f} \u00b1 {f1s.std():.4f}") @@ -355,7 +356,7 @@ def _main(args: argparse.Namespace) -> None: print("STEP 1: Load TUEV + build task dataset (shared across all seeds)") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV()) + sample_dataset = dataset.set_task(EEGEventsTUEV(normalization="95th_percentile"), num_workers=16) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") diff --git a/examples/conformal_eeg/tuev_ncp_conformal.py b/examples/conformal_eeg/tuev_ncp_conformal.py index 2721656b0..77cc98475 100644 --- a/examples/conformal_eeg/tuev_ncp_conformal.py +++ b/examples/conformal_eeg/tuev_ncp_conformal.py @@ -319,7 +319,7 @@ def _run(args: argparse.Namespace) -> None: print("STEP 1: Load TUEV + build task dataset") print("=" * 80) dataset = TUEVDataset(root=str(root), subset=args.subset, dev=args.quick_test) - sample_dataset = dataset.set_task(EEGEventsTUEV()) + sample_dataset = dataset.set_task(EEGEventsTUEV(normalization="95th_percentile"), num_workers=16) if args.quick_test and len(sample_dataset) > quick_test_max_samples: sample_dataset = sample_dataset.subset(range(quick_test_max_samples)) print(f"Capped to {quick_test_max_samples} samples for quick-test.") @@ -409,7 +409,10 @@ def _run(args: argparse.Namespace) -> None: set_sizes = np.array([m["avg_set_size"] for m in mlist]) if not use_multi_seed: - print(f"\nNCP Results (alpha={alpha}):") + print("\n" + "=" * 80) + print(f"Summary — alpha={alpha} (single run, fixed test set)") + print(" Method: NeighborhoodLabel") + print("=" * 80) print(f" Accuracy: {accs[0]:.4f}") print(f" F1 (weighted): {f1s[0]:.4f}") print(f" Empirical coverage: {coverages[0]:.4f}") @@ -419,7 +422,7 @@ def _run(args: argparse.Namespace) -> None: print(f" k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") else: print("\n" + "=" * 80) - print(f"Per-run NCP results — alpha={alpha} (target coverage={1-alpha:.0%})") + print(f"Per-run results — alpha={alpha} (NeighborhoodLabel, fixed test set = TUH eval partition)") print("=" * 80) print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'F1-Wt':<10} " f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}") @@ -429,13 +432,14 @@ def _run(args: argparse.Namespace) -> None: f"{coverages[i]:<10.4f} {miscovs[i]:<12.4f} {set_sizes[i]:<12.2f}") print("\n" + "=" * 80) - print(f"NCP summary — alpha={alpha} (mean ± std over {n_runs} runs, fixed test set)") + print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)") + print(" Method: NeighborhoodLabel") print("=" * 80) - print(f" Accuracy: {accs.mean():.4f} ± {accs.std():.4f}") - print(f" F1 (weighted): {f1s.mean():.4f} ± {f1s.std():.4f}") - print(f" Empirical coverage: {coverages.mean():.4f} ± {coverages.std():.4f}") - print(f" Empirical miscoverage: {miscovs.mean():.4f} ± {miscovs.std():.4f}") - print(f" Average set size: {set_sizes.mean():.2f} ± {set_sizes.std():.2f}") + print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}") + print(f" F1 (weighted): {f1s.mean():.4f} \u00b1 {f1s.std():.4f}") + print(f" Empirical coverage: {coverages.mean():.4f} \u00b1 {coverages.std():.4f}") + print(f" Empirical miscoverage: {miscovs.mean():.4f} \u00b1 {miscovs.std():.4f}") + print(f" Average set size: {set_sizes.mean():.2f} \u00b1 {set_sizes.std():.2f}") print(f" Target coverage: {1 - alpha:.0%} (alpha={alpha})") print(f" k_neighbors: {args.k_neighbors}, lambda_L: {args.lambda_L}") print(f" Test set size: {n_test} (fixed across runs)") diff --git a/pyhealth/tasks/temple_university_EEG_tasks.py b/pyhealth/tasks/temple_university_EEG_tasks.py index fc2ba702a..13e8c7206 100644 --- a/pyhealth/tasks/temple_university_EEG_tasks.py +++ b/pyhealth/tasks/temple_university_EEG_tasks.py @@ -229,7 +229,16 @@ class EEGAbnormalTUAB(BaseTask): task_name: str = "EEG_abnormal" input_schema: Dict[str, str] = {"signal": "tensor", "stft": "tensor"} - output_schema: Dict[str, str] = {"label": "binary"} + # NOTE: TUAB is a binary classification task (normal=0 vs abnormal=1), but the + # output schema is intentionally set to "multiclass" rather than "binary". + # Reason: PyHealth's conformal prediction methods (LABEL, ClusterLabel, + # NeighborhoodLabel, CovariateLabel) require multiclass mode — they calibrate + # prediction sets by thresholding a full (n, K) probability matrix, which is + # only produced by a softmax output (multiclass). Binary mode uses sigmoid and + # outputs (n, 1), which is incompatible with the CP calibration math. + # For a 2-class problem, 2-class softmax is mathematically equivalent to + # sigmoid, so there is no loss of correctness, just a different representation. + output_schema: Dict[str, str] = {"label": "multiclass"} def __init__(self, resample_rate: float = 200, diff --git a/tests/core/test_tuab.py b/tests/core/test_tuab.py index bf6e25d45..2559b13a8 100644 --- a/tests/core/test_tuab.py +++ b/tests/core/test_tuab.py @@ -531,7 +531,7 @@ def test_task_schema_attributes(self): task = EEGAbnormalTUAB() self.assertEqual(task.task_name, "EEG_abnormal") self.assertEqual(task.input_schema, {"signal": "tensor", "stft": "tensor"}) - self.assertEqual(task.output_schema, {"label": "binary"}) + self.assertEqual(task.output_schema, {"label": "multiclass"}) def test_task_schema_no_stft(self): task = EEGAbnormalTUAB(compute_stft=False)