Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/cxr/covid19cxr_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,7 @@
" # Input size is inferred automatically from image dimensions\n",
" result = chefer_gen.attribute(\n",
" interpolate=True,\n",
" class_index=pred_class,\n",
" target_class_idx=pred_class,\n",
" **batch\n",
" )\n",
" attr_map = result[\"image\"] # Keyed by task schema's feature key\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/cxr/covid19cxr_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
# Compute attribution for each class in the prediction set
overlays = []
for class_idx in predset_class_indices:
attr_map = chefer.attribute(class_index=class_idx, **batch)["image"]
attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"]
_, _, overlay = visualize_image_attr(
image=batch["image"][0],
attribution=attr_map[0, 0],
Expand Down
2 changes: 1 addition & 1 deletion examples/cxr/covid19cxr_tutorial_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
# Compute attribution for each class in the prediction set
overlays = []
for class_idx in predset_class_indices:
attr_map = chefer.attribute(class_index=class_idx, **batch)["image"]
attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"]
_, _, overlay = visualize_image_attr(
image=batch["image"][0],
attribution=attr_map[0, 0],
Expand Down
189 changes: 189 additions & 0 deletions examples/interpretability/custom_sample_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""Evaluate all interpretability methods on StageNet + MIMIC-IV dataset using comprehensiveness
and sufficiency metrics.

This example demonstrates:
1. Loading a pre-trained StageNet model with processors and MIMIC-IV dataset
2. Computing attributions with various interpretability methods
3. Evaluating attribution faithfulness with Comprehensiveness & Sufficiency for each method
4. Presenting results in a summary table
"""

import datetime
import argparse

import torch
from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient
from pyhealth.interpret.methods import *
from pyhealth.metrics.interpretability import evaluate_attribution
from pyhealth.metrics.interpretability.utils import SampleClass
from pyhealth.models import Transformer
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4
from pyhealth.trainer import Trainer
from pyhealth.datasets.utils import load_processors
from pathlib import Path
import pandas as pd

# python -u examples/interpretability/custom_sample_filter.py --pos_threshold 0.5 --neg_threshold 0.1 --device cuda:2
def main():
    """Evaluate Integrated Gradients on Transformer + MIMIC-IV.

    Steps:
        1. Parse CLI thresholds and device.
        2. Load the MIMIC-IV dataset, cached processors, and the
           mortality-prediction task.
        3. Restore a pre-trained Transformer checkpoint.
        4. Evaluate attribution faithfulness (comprehensiveness and
           sufficiency) with a custom sample filter driven by the
           positive/negative probability thresholds.
    """
    # Fixed description: the parser describes the script, not a
    # nonexistent "comma separated list" argument.
    parser = argparse.ArgumentParser(
        description=(
            "Evaluate Integrated Gradients attribution faithfulness on a "
            "Transformer + MIMIC-IV mortality-prediction model."
        )
    )
    # Help text now matches the actual default (None = threshold disabled).
    parser.add_argument(
        "--pos_threshold",
        type=float,
        default=None,
        help="Positive threshold for interpretability evaluation "
        "(default: None, i.e. no positive-side filtering).",
    )
    parser.add_argument(
        "--neg_threshold",
        type=float,
        default=None,
        help="Negative threshold for interpretability evaluation "
        "(default: None, i.e. no negative-side filtering).",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="Device to use for evaluation (default: cuda:0)",
    )
    args = parser.parse_args()

    print("=" * 70)
    print("Interpretability Metrics Example: Transformer + MIMIC-IV")
    print("=" * 70)

    now = datetime.datetime.now()
    print(f"Start Time: {now.strftime('%Y-%m-%d %H:%M:%S')}")

    # Set paths.
    # NOTE(review): these are machine-specific absolute paths — consider
    # making them CLI arguments.
    CACHE_DIR = Path("/home/yongdaf2/interpret/cache/mp_mimic4")
    CKPTS_DIR = Path("/shared/eng/pyhealth_dka/ckpts/mp_transformer_mimic4")
    OUTPUT_DIR = Path("/home/yongdaf2/interpret/output/mp_transformer_mimic4")
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    CKPTS_DIR.mkdir(parents=True, exist_ok=True)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"\nUsing cache dir: {CACHE_DIR}")
    print(f"Using checkpoints dir: {CKPTS_DIR}")
    print(f"Using output dir: {OUTPUT_DIR}")

    # Set device.
    device = args.device
    print(f"\nUsing device: {device}")

    # Load MIMIC-IV dataset.
    print("\n Loading MIMIC-IV dataset...")
    base_dataset = MIMIC4Dataset(
        ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
        ehr_tables=[
            "patients",
            "admissions",
            "diagnoses_icd",
            "procedures_icd",
            "labevents",
        ],
        cache_dir=str(CACHE_DIR),
        num_workers=16,
    )

    # Apply mortality prediction task; the processors must come from the
    # checkpoint so the vocabulary matches the trained model.
    if not (CKPTS_DIR / "input_processors.pkl").exists():
        raise FileNotFoundError(f"Input processors not found in {CKPTS_DIR}. ")
    if not (CKPTS_DIR / "output_processors.pkl").exists():
        raise FileNotFoundError(f"Output processors not found in {CKPTS_DIR}. ")
    input_processors, output_processors = load_processors(str(CKPTS_DIR))
    print("✓ Loaded input and output processors from checkpoint directory.")

    sample_dataset = base_dataset.set_task(
        MortalityPredictionStageNetMIMIC4(),
        num_workers=16,
        input_processors=input_processors,
        output_processors=output_processors,
    )
    print(f"✓ Loaded {len(sample_dataset)} samples")

    # Split dataset and get test loader (same seed as training for a
    # disjoint test split).
    _, _, test_dataset = split_by_patient(sample_dataset, [0.9, 0.09, 0.01], seed=233)
    test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)
    print(f"✓ Test set: {len(test_dataset)} samples")

    # Initialize and load pre-trained model.
    print("\n Loading pre-trained Transformer model...")
    model = Transformer(
        dataset=sample_dataset,
        embedding_dim=128,
        heads=4,
        dropout=0.3,
        num_layers=3,
    )

    trainer = Trainer(model=model, device=device)
    trainer.load_ckpt(str(CKPTS_DIR / "best.ckpt"))
    model = model.to(device)
    model.eval()
    print(f"✓ Loaded checkpoint: {CKPTS_DIR / 'best.ckpt'}")
    print(f"✓ Model moved to {device}")

    pos_threshold = args.pos_threshold
    neg_threshold = args.neg_threshold

    def sample_filter_fn(
        y_probs: torch.Tensor,
        classifier_type: str,
    ) -> torch.Tensor:
        """Classify samples by positive/negative probability thresholds.

        negative samples: y_probs < neg_threshold
        ignored samples:  neg_threshold <= y_probs < pos_threshold
        positive samples: y_probs >= pos_threshold

        A threshold of None disables the corresponding cut-off; with both
        None every sample is POSITIVE. For classifier types other than
        binary/multilabel, no filtering is applied.
        """
        # pos_threshold / neg_threshold are only read, so no `nonlocal`
        # declaration is needed.
        batch_size = y_probs.shape[0]
        result = torch.full(
            (batch_size,),
            SampleClass.POSITIVE,
            dtype=torch.long,
            device=y_probs.device,
        )
        if classifier_type in ("binary", "multilabel"):
            # Order matters: NEGATIVE overwrites IGNORE for the lowest band.
            if pos_threshold is not None:
                result[y_probs < pos_threshold] = SampleClass.IGNORE
            if neg_threshold is not None:
                result[y_probs < neg_threshold] = SampleClass.NEGATIVE
        return result

    interpreter = IntegratedGradients(model, use_embeddings=True)
    print("\nEvaluating using Integrated Gradients...")

    # Functional API (simple one-off evaluation over the whole loader).
    print("\nEvaluating with Functional API on full dataset...")
    print("Using: evaluate_attribution(model, dataloader, method, ...)")

    results_functional = evaluate_attribution(
        model,
        test_loader,
        interpreter,
        metrics=["comprehensiveness", "sufficiency"],
        percentages=[25, 50, 99],
        sample_filter=sample_filter_fn,
    )

    print("\n" + "=" * 70)
    print("Dataset-Wide Results (Functional API)")
    print("=" * 70)
    comp = results_functional["comprehensiveness"]
    suff = results_functional["sufficiency"]
    print(f"\nComprehensiveness: {comp:.4f}")
    print(f"Sufficiency: {suff:.4f}")

    print("")
    print("=" * 70)
    print("Summary of Results for All Methods")
    print({"Method": "Integrated Gradients", "Comprehensiveness": comp, "Sufficiency": suff})

    end = datetime.datetime.now()
    print(f"End Time: {end.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Total Duration: {end - now}")


if __name__ == "__main__":
    main()
46 changes: 43 additions & 3 deletions pyhealth/interpret/methods/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from abc import ABC, abstractmethod
from typing import Dict, cast
from typing import Dict, Optional, cast

import torch
import torch.nn as nn
Expand Down Expand Up @@ -138,8 +138,12 @@ def attribute(
by the task's ``input_schema``.
- Label key (optional): Ground truth labels, may be needed
by some methods for loss computation.
- ``class_index`` (optional): Target class for attribution.
If not provided, uses the predicted class.
- ``target_class_idx`` (Optional[int]): Target class for
attribution. For binary classification (single logit
output), this is a no-op because there is only one
output. For multi-class or multi-label classification,
specifies which class index to explain. If not provided,
uses the argmax of logits.
- Additional method-specific parameters (e.g., ``baseline``,
``steps``, ``interpolate``).

Expand Down Expand Up @@ -207,6 +211,42 @@ def attribute(
"""
pass

def _resolve_target_indices(
self,
logits: torch.Tensor,
target_class_idx: Optional[int],
) -> torch.Tensor:
"""Resolve target class indices for attribution.

Returns a ``[batch]`` tensor of class indices identifying which
logit to explain. All prediction modes share this single code
path:

* **Binary** (single logit): ``target_class_idx`` is a no-op
because there is only one output. Always returns zeros
(index 0).
* **Multi-class / multi-label**: uses ``target_class_idx`` if
given, otherwise the argmax of logits.

Args:
logits: Model output logits, shape ``[batch, num_classes]``.
target_class_idx: Optional user-specified class index.

Returns:
``torch.LongTensor`` of shape ``[batch]``.
"""
if logits.shape[-1] == 1:
# Single logit output — nothing to select.
return torch.zeros(
logits.shape[0], device=logits.device, dtype=torch.long,
)
if target_class_idx is not None:
return torch.full(
(logits.shape[0],), target_class_idx,
device=logits.device, dtype=torch.long,
)
return logits.argmax(dim=-1)

def _prediction_mode(self) -> str:
"""Resolve the prediction mode from the model.

Expand Down
19 changes: 8 additions & 11 deletions pyhealth/interpret/methods/chefer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class CheferRelevance(BaseInterpreter):
>>> print(attributions["conditions"].shape) # [batch, num_tokens]
>>>
>>> # Optional: attribute to a specific class (e.g., class 1)
>>> attributions = interpreter.attribute(class_index=1, **batch)
>>> attributions = interpreter.attribute(target_class_idx=1, **batch)
"""

def __init__(self, model: BaseModel):
Expand All @@ -139,14 +139,16 @@ def __init__(self, model: BaseModel):

def attribute(
self,
class_index: Optional[int] = None,
target_class_idx: Optional[int] = None,
**data,
) -> Dict[str, torch.Tensor]:
"""Compute relevance scores for each input token.

Args:
class_index: Target class index to compute attribution for.
If None (default), uses the model's predicted class.
target_class_idx: Target class index to compute attribution for.
If None (default), uses the argmax of model output.
For binary classification (single logit output), this is
a no-op because there is only one output.
**data: Input data from dataloader batch containing feature
keys and label key.

Expand All @@ -163,15 +165,10 @@ def attribute(
self.model.set_attention_hooks(False)

# --- 2. Backward from target class ---
if class_index is None:
class_index_t = torch.argmax(logits, dim=-1)
elif isinstance(class_index, int):
class_index_t = torch.tensor(class_index)
else:
class_index_t = class_index
target_indices = self._resolve_target_indices(logits, target_class_idx)

one_hot = F.one_hot(
class_index_t.detach().clone(), logits.size(1)
target_indices.detach().clone(), logits.size(1)
).float()
one_hot = one_hot.requires_grad_(True)
scalar = torch.sum(one_hot.to(logits.device) * logits)
Expand Down
Loading
Loading