diff --git a/examples/cxr/covid19cxr_tutorial.ipynb b/examples/cxr/covid19cxr_tutorial.ipynb index 2a04844c5..ec10756a1 100644 --- a/examples/cxr/covid19cxr_tutorial.ipynb +++ b/examples/cxr/covid19cxr_tutorial.ipynb @@ -1339,7 +1339,7 @@ " # Input size is inferred automatically from image dimensions\n", " result = chefer_gen.attribute(\n", " interpolate=True,\n", - " class_index=pred_class,\n", + " target_class_idx=pred_class,\n", " **batch\n", " )\n", " attr_map = result[\"image\"] # Keyed by task schema's feature key\n", diff --git a/examples/cxr/covid19cxr_tutorial.py b/examples/cxr/covid19cxr_tutorial.py index 0f24f4b58..06b134f93 100644 --- a/examples/cxr/covid19cxr_tutorial.py +++ b/examples/cxr/covid19cxr_tutorial.py @@ -131,7 +131,7 @@ # Compute attribution for each class in the prediction set overlays = [] for class_idx in predset_class_indices: - attr_map = chefer.attribute(class_index=class_idx, **batch)["image"] + attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"] _, _, overlay = visualize_image_attr( image=batch["image"][0], attribution=attr_map[0, 0], diff --git a/examples/cxr/covid19cxr_tutorial_display.py b/examples/cxr/covid19cxr_tutorial_display.py index 3f6a33b82..f3a4acddb 100644 --- a/examples/cxr/covid19cxr_tutorial_display.py +++ b/examples/cxr/covid19cxr_tutorial_display.py @@ -128,7 +128,7 @@ # Compute attribution for each class in the prediction set overlays = [] for class_idx in predset_class_indices: - attr_map = chefer.attribute(class_index=class_idx, **batch)["image"] + attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"] _, _, overlay = visualize_image_attr( image=batch["image"][0], attribution=attr_map[0, 0], diff --git a/examples/interpretability/custom_sample_filter.py b/examples/interpretability/custom_sample_filter.py new file mode 100644 index 000000000..da59546c5 --- /dev/null +++ b/examples/interpretability/custom_sample_filter.py @@ -0,0 +1,189 @@ +"""Evaluate all interpretability 
methods on Transformer + MIMIC-IV dataset using comprehensiveness +and sufficiency metrics. + +This example demonstrates: +1. Loading a pre-trained Transformer model with processors and MIMIC-IV dataset +2. Computing attributions with various interpretability methods +3. Evaluating attribution faithfulness with Comprehensiveness & Sufficiency for each method +4. Presenting results in a summary table +""" + +import datetime +import argparse + +import torch +from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient +from pyhealth.interpret.methods import * +from pyhealth.metrics.interpretability import evaluate_attribution +from pyhealth.metrics.interpretability.utils import SampleClass +from pyhealth.models import Transformer +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 +from pyhealth.trainer import Trainer +from pyhealth.datasets.utils import load_processors +from pathlib import Path +import pandas as pd + +# python -u examples/interpretability/custom_sample_filter.py --pos_threshold 0.5 --neg_threshold 0.1 --device cuda:2 +def main(): + parser = argparse.ArgumentParser( + description="Comma separated list of interpretability methods to evaluate" + ) + parser.add_argument( + "--pos_threshold", + type=float, + default=None, + help="Positive threshold for interpretability evaluation (default: 0.5).", + ) + parser.add_argument( + "--neg_threshold", + type=float, + default=None, + help="Negative threshold for interpretability evaluation (default: 0.5).", + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0", + help="Device to use for evaluation (default: cuda:0)", + ) + args = parser.parse_args() + """Main execution function.""" + print("=" * 70) + print("Interpretability Metrics Example: Transformer + MIMIC-IV") + print("=" * 70) + + now = datetime.datetime.now() + print(f"Start Time: {now.strftime('%Y-%m-%d %H:%M:%S')}") + + # Set path + CACHE_DIR = Path("/home/yongdaf2/interpret/cache/mp_mimic4") + CKPTS_DIR = 
Path("/shared/eng/pyhealth_dka/ckpts/mp_transformer_mimic4") + OUTPUT_DIR = Path("/home/yongdaf2/interpret/output/mp_transformer_mimic4") + CACHE_DIR.mkdir(parents=True, exist_ok=True) + CKPTS_DIR.mkdir(parents=True, exist_ok=True) + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + print(f"\nUsing cache dir: {CACHE_DIR}") + print(f"Using checkpoints dir: {CKPTS_DIR}") + print(f"Using output dir: {OUTPUT_DIR}") + + # Set device + device = args.device + print(f"\nUsing device: {device}") + + # Load MIMIC-IV dataset + print("\n Loading MIMIC-IV dataset...") + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + cache_dir=str(CACHE_DIR), + num_workers=16, + ) + + # Apply mortality prediction task + if not (CKPTS_DIR / "input_processors.pkl").exists(): + raise FileNotFoundError(f"Input processors not found in {CKPTS_DIR}. ") + if not (CKPTS_DIR / "output_processors.pkl").exists(): + raise FileNotFoundError(f"Output processors not found in {CKPTS_DIR}. 
") + input_processors, output_processors = load_processors(str(CKPTS_DIR)) + print("✓ Loaded input and output processors from checkpoint directory.") + + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=16, + input_processors=input_processors, + output_processors=output_processors, + ) + print(f"✓ Loaded {len(sample_dataset)} samples") + + # Split dataset and get test loader + _, _, test_dataset = split_by_patient(sample_dataset, [0.9, 0.09, 0.01], seed=233) + test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False) + print(f"✓ Test set: {len(test_dataset)} samples") + + # Initialize and load pre-trained model + print("\n Loading pre-trained Transformer model...") + model = Transformer( + dataset=sample_dataset, + embedding_dim=128, + heads=4, + dropout=0.3, + num_layers=3, + ) + + trainer = Trainer(model=model, device=device) + trainer.load_ckpt(str(CKPTS_DIR / "best.ckpt")) + model = model.to(device) + model.eval() + print(f"✓ Loaded checkpoint: {CKPTS_DIR / 'best.ckpt'}") + print(f"✓ Model moved to {device}") + + pos_threshold = args.pos_threshold + neg_threshold = args.neg_threshold + def sample_filter_fn( + y_probs: torch.Tensor, + classifier_type: str, + ) -> torch.Tensor: + """ + Custom sample filter function that classifies samples based on + positive and negative probability thresholds. 
+ + negative samples: 0 < y_probs < neg_threshold + ignored samples: neg_threshold <= y_probs < pos_threshold + positive samples: y_probs >= pos_threshold + """ + nonlocal pos_threshold, neg_threshold + batch_size = y_probs.shape[0] + result = torch.full( + (batch_size,), + SampleClass.POSITIVE, + dtype=torch.long, + device=y_probs.device, + ) + if classifier_type in ("binary", "multilabel"): + if pos_threshold is not None: + result[y_probs < pos_threshold] = SampleClass.IGNORE + if neg_threshold is not None: + result[y_probs < neg_threshold] = SampleClass.NEGATIVE + return result + + interpreter = IntegratedGradients(model, use_embeddings=True) + print(f"\nEvaluating using Integrated Gradients...") + + # Option 1: Functional API (simple one-off evaluation) + print("\nEvaluating with Functional API on full dataset...") + print("Using: evaluate_attribution(model, dataloader, method, ...)") + + results_functional = evaluate_attribution( + model, + test_loader, + interpreter, + metrics=["comprehensiveness", "sufficiency"], + percentages=[25, 50, 99], + sample_filter=sample_filter_fn, + ) + + print("\n" + "=" * 70) + print("Dataset-Wide Results (Functional API)") + print("=" * 70) + comp = results_functional["comprehensiveness"] + suff = results_functional["sufficiency"] + print(f"\nComprehensiveness: {comp:.4f}") + print(f"Sufficiency: {suff:.4f}") + + print("") + print("=" * 70) + print("Summary of Results for All Methods") + print({"Method": "Integrated Gradients", "Comprehensiveness": comp, "Sufficiency": suff}) + + end = datetime.datetime.now() + print(f"End Time: {end.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"Total Duration: {end - now}") + +if __name__ == "__main__": + main() diff --git a/pyhealth/interpret/methods/base_interpreter.py b/pyhealth/interpret/methods/base_interpreter.py index de75c897a..fb17690ae 100644 --- a/pyhealth/interpret/methods/base_interpreter.py +++ b/pyhealth/interpret/methods/base_interpreter.py @@ -10,7 +10,7 @@ """ from abc import 
ABC, abstractmethod -from typing import Dict, cast +from typing import Dict, Optional, cast import torch import torch.nn as nn @@ -138,8 +138,12 @@ def attribute( by the task's ``input_schema``. - Label key (optional): Ground truth labels, may be needed by some methods for loss computation. - - ``class_index`` (optional): Target class for attribution. - If not provided, uses the predicted class. + - ``target_class_idx`` (Optional[int]): Target class for + attribution. For binary classification (single logit + output), this is a no-op because there is only one + output. For multi-class or multi-label classification, + specifies which class index to explain. If not provided, + uses the argmax of logits. - Additional method-specific parameters (e.g., ``baseline``, ``steps``, ``interpolate``). @@ -207,6 +211,42 @@ def attribute( """ pass + def _resolve_target_indices( + self, + logits: torch.Tensor, + target_class_idx: Optional[int], + ) -> torch.Tensor: + """Resolve target class indices for attribution. + + Returns a ``[batch]`` tensor of class indices identifying which + logit to explain. All prediction modes share this single code + path: + + * **Binary** (single logit): ``target_class_idx`` is a no-op + because there is only one output. Always returns zeros + (index 0). + * **Multi-class / multi-label**: uses ``target_class_idx`` if + given, otherwise the argmax of logits. + + Args: + logits: Model output logits, shape ``[batch, num_classes]``. + target_class_idx: Optional user-specified class index. + + Returns: + ``torch.LongTensor`` of shape ``[batch]``. + """ + if logits.shape[-1] == 1: + # Single logit output — nothing to select. 
+ return torch.zeros( + logits.shape[0], device=logits.device, dtype=torch.long, + ) + if target_class_idx is not None: + return torch.full( + (logits.shape[0],), target_class_idx, + device=logits.device, dtype=torch.long, + ) + return logits.argmax(dim=-1) + def _prediction_mode(self) -> str: """Resolve the prediction mode from the model. diff --git a/pyhealth/interpret/methods/chefer.py b/pyhealth/interpret/methods/chefer.py index 5efca68eb..26ce6ffd2 100644 --- a/pyhealth/interpret/methods/chefer.py +++ b/pyhealth/interpret/methods/chefer.py @@ -128,7 +128,7 @@ class CheferRelevance(BaseInterpreter): >>> print(attributions["conditions"].shape) # [batch, num_tokens] >>> >>> # Optional: attribute to a specific class (e.g., class 1) - >>> attributions = interpreter.attribute(class_index=1, **batch) + >>> attributions = interpreter.attribute(target_class_idx=1, **batch) """ def __init__(self, model: BaseModel): @@ -139,14 +139,16 @@ def __init__(self, model: BaseModel): def attribute( self, - class_index: Optional[int] = None, + target_class_idx: Optional[int] = None, **data, ) -> Dict[str, torch.Tensor]: """Compute relevance scores for each input token. Args: - class_index: Target class index to compute attribution for. - If None (default), uses the model's predicted class. + target_class_idx: Target class index to compute attribution for. + If None (default), uses the argmax of model output. + For binary classification (single logit output), this is + a no-op because there is only one output. **data: Input data from dataloader batch containing feature keys and label key. @@ -163,15 +165,10 @@ def attribute( self.model.set_attention_hooks(False) # --- 2. 
Backward from target class --- - if class_index is None: - class_index_t = torch.argmax(logits, dim=-1) - elif isinstance(class_index, int): - class_index_t = torch.tensor(class_index) - else: - class_index_t = class_index + target_indices = self._resolve_target_indices(logits, target_class_idx) one_hot = F.one_hot( - class_index_t.detach().clone(), logits.size(1) + target_indices.detach().clone(), logits.size(1) ).float() one_hot = one_hot.requires_grad_(True) scalar = torch.sum(one_hot.to(logits.device) * logits) diff --git a/pyhealth/interpret/methods/deeplift.py b/pyhealth/interpret/methods/deeplift.py index 29f99d795..8f00f8a8b 100644 --- a/pyhealth/interpret/methods/deeplift.py +++ b/pyhealth/interpret/methods/deeplift.py @@ -411,35 +411,7 @@ def attribute( with torch.no_grad(): base_logits = self.model.forward(**inputs)["logit"] - mode = self._prediction_mode() - if mode == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif mode == "multiclass": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = torch.argmax(base_logits, dim=-1) - target = F.one_hot( - target, num_classes=base_logits.shape[-1] - ).float() - elif mode == "multilabel": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = (torch.sigmoid(base_logits) > 0.5).float() - else: - raise ValueError( - "Unsupported prediction mode for DeepLIFT attribution." 
- ) + target_indices = self._resolve_target_indices(base_logits, target_class_idx) # Generate baselines if baseline is None: @@ -491,7 +463,7 @@ def attribute( inputs=inputs, xs=values, bs=baselines, - target=target, + target_indices=target_indices, token_keys=token_keys, ) @@ -505,7 +477,7 @@ def _deeplift( inputs: Dict[str, tuple[torch.Tensor, ...]], xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], - target: torch.Tensor, + target_indices: torch.Tensor, token_keys: set[str], ) -> Dict[str, torch.Tensor]: """Core DeepLIFT computation using the Rescale rule. @@ -517,8 +489,7 @@ def _deeplift( inputs: Full input tuples keyed by feature name. xs: Input values (embedded if token features with use_embeddings). bs: Baseline values (embedded if token features with use_embeddings). - target: Target tensor for computing the scalar output to - differentiate (one-hot for multiclass, class idx for binary). + target_indices: [batch] tensor of target class indices. token_keys: Set of feature keys that are token (already embedded). Returns: @@ -590,9 +561,9 @@ def _maybe_embed_continuous(value_dict: dict[str, torch.Tensor]) -> dict[str, to baseline_logits = baseline_output["logit"] # type: ignore[index] # Compute per-sample target outputs - target_output = self._compute_target_output(logits, target) + target_output = self._compute_target_output(logits, target_indices) baseline_target_output = self._compute_target_output( - baseline_logits, target + baseline_logits, target_indices ) self.model.zero_grad(set_to_none=True) @@ -626,46 +597,22 @@ def _maybe_embed_continuous(value_dict: dict[str, torch.Tensor]) -> dict[str, to def _compute_target_output( self, logits: torch.Tensor, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Compute per-sample target output. - Creates a differentiable per-sample scalar from the model logits - that, when summed and differentiated, gives the gradient of the - target class logit w.r.t. the input. 
+ Selects the target-class logit for each sample. Args: - logits: Model output logits, shape [batch, num_classes] or - [batch, 1]. - target: Target tensor. For binary: [batch] or [1] with 0/1 - class indices. For multiclass/multilabel: [batch, num_classes] - one-hot or multi-hot tensor. + logits: Model output logits, shape [batch, num_classes]. + target_indices: [batch] tensor of target class indices. Returns: Per-sample target output tensor, shape [batch]. """ - target_f = target.to(logits.device).float() - mode = self._prediction_mode() - - if mode == "binary": - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(-1) - target_f = target_f.expand_as(logits) - signs = 2.0 * target_f - 1.0 - # Sum over all dims except batch to get per-sample scalar - per_sample = (signs * logits) - if per_sample.dim() > 1: - per_sample = per_sample.sum(dim=tuple(range(1, per_sample.dim()))) - return per_sample - else: - # multiclass or multilabel: target is one-hot/multi-hot - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(0) - target_f = target_f.expand_as(logits) - per_sample = (target_f * logits) - if per_sample.dim() > 1: - per_sample = per_sample.sum(dim=tuple(range(1, per_sample.dim()))) - return per_sample + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1) # ------------------------------------------------------------------ # Completeness enforcement diff --git a/pyhealth/interpret/methods/gim.py b/pyhealth/interpret/methods/gim.py index d1b74b573..abbb9388a 100644 --- a/pyhealth/interpret/methods/gim.py +++ b/pyhealth/interpret/methods/gim.py @@ -365,8 +365,9 @@ def attribute( """Compute GIM attributions for a batch. Args: - target_class_idx: Target class index for attribution. If None, - uses the model's predicted class. + target_class_idx: Target class index for attribution. For + binary classification (single logit output), this is a + no-op. If None, uses the argmax of model output. 
**kwargs: Input data dictionary from a dataloader batch containing feature tensors or tuples of tensors for each modality, plus optional label tensors. @@ -419,35 +420,7 @@ def attribute( with torch.no_grad(): base_logits = self.model.forward(**inputs)["logit"] - mode = self._prediction_mode() - if mode == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif mode == "multiclass": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = torch.argmax(base_logits, dim=-1) - target = F.one_hot( - target, num_classes=base_logits.shape[-1] - ).float() - elif mode == "multilabel": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = (torch.sigmoid(base_logits) > 0.5).float() - else: - raise ValueError( - "Unsupported prediction mode for GIM attribution." - ) + target_indices = self._resolve_target_indices(base_logits, target_class_idx) # Embed values and detach for gradient attribution. # Split features by type using is_token(): @@ -521,7 +494,7 @@ def attribute( output = self.model.forward_from_embedding(**forward_inputs) logits = output["logit"] # type: ignore[assignment] - target_output = self._compute_target_output(logits, target) + target_output = self._compute_target_output(logits, target_indices) # Clear stale gradients, then backpropagate through the # GIM-modified computational graph. @@ -552,39 +525,23 @@ def attribute( def _compute_target_output( self, logits: torch.Tensor, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Compute scalar target output for backpropagation. 
- Creates a differentiable scalar from the model logits that, - when differentiated, gives the gradient of the target class - logit w.r.t. the input. + Selects the target-class logit for each sample and sums over + the batch to produce a single differentiable scalar. Args: - logits: Model output logits, shape [batch, num_classes] or - [batch, 1]. - target: Target tensor. For binary: [batch] or [1] with 0/1 - class indices. For multiclass/multilabel: [batch, num_classes] - one-hot or multi-hot tensor. + logits: Model output logits, shape [batch, num_classes]. + target_indices: [batch] tensor of target class indices. Returns: Scalar tensor for backpropagation. """ - target_f = target.to(logits.device).float() - mode = self._prediction_mode() - - if mode == "binary": - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(-1) - target_f = target_f.expand_as(logits) - signs = 2.0 * target_f - 1.0 - return (signs * logits).sum() - else: - # multiclass or multilabel: target is one-hot/multi-hot - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(0) - target_f = target_f.expand_as(logits) - return (target_f * logits).sum() + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1).sum() # ------------------------------------------------------------------ # Utility helpers diff --git a/pyhealth/interpret/methods/ig_gim.py b/pyhealth/interpret/methods/ig_gim.py index a33f5529b..49c2fa6c0 100644 --- a/pyhealth/interpret/methods/ig_gim.py +++ b/pyhealth/interpret/methods/ig_gim.py @@ -105,8 +105,9 @@ def attribute( near-zero for continuous features). steps: Number of interpolation steps. Overrides the instance default when given. - target_class_idx: Target class for attribution. ``None`` - uses the model's predicted class. + target_class_idx: Target class for attribution. For binary + classification (single logit output), this is a no-op. + ``None`` uses the argmax of model output. 
**kwargs: Dataloader batch (feature tensors + optional labels). Returns: @@ -155,10 +156,7 @@ def attribute( with torch.no_grad(): base_logits = self.model.forward(**inputs)["logit"] - mode = self._prediction_mode() - target = self._resolve_target( - base_logits, mode, target_class_idx, device - ) + target_indices = self._resolve_target_indices(base_logits, target_class_idx) # ----- baselines ----- if baseline is None: @@ -201,7 +199,7 @@ def attribute( xs=values, bs=baselines, steps=steps, - target=target, + target_indices=target_indices, token_keys=token_keys, continuous_keys=continuous_keys, ) @@ -217,7 +215,7 @@ def _integrated_gradients_gim( xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], steps: int, - target: torch.Tensor, + target_indices: torch.Tensor, token_keys: set[str], continuous_keys: set[str], ) -> Dict[str, torch.Tensor]: @@ -280,7 +278,7 @@ def _integrated_gradients_gim( with _GIMHookContext(self.model, self.temperature): output = self.model.forward_from_embedding(**forward_inputs) logits = output["logit"] - target_output = self._compute_target_output(logits, target) + target_output = self._compute_target_output(logits, target_indices) self.model.zero_grad(set_to_none=True) target_output.backward(retain_graph=True) @@ -308,58 +306,16 @@ def _integrated_gradients_gim( # ------------------------------------------------------------------ # Target helpers (shared logic with IG / GIM) # ------------------------------------------------------------------ - @staticmethod - def _resolve_target( - logits: torch.Tensor, - mode: str, - target_class_idx: Optional[int], - device: torch.device, - ) -> torch.Tensor: - """Convert logits and optional class index into a target tensor.""" - if mode == "binary": - if target_class_idx is not None: - return torch.tensor([target_class_idx], device=device) - return (torch.sigmoid(logits) > 0.5).long() - - if mode == "multiclass": - if target_class_idx is not None: - return F.one_hot( - 
torch.tensor(target_class_idx, device=device), - num_classes=logits.shape[-1], - ).float() - target = torch.argmax(logits, dim=-1) - return F.one_hot(target, num_classes=logits.shape[-1]).float() - - if mode == "multilabel": - if target_class_idx is not None: - return F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=logits.shape[-1], - ).float() - return (torch.sigmoid(logits) > 0.5).float() - - raise ValueError(f"Unsupported prediction mode: {mode}") def _compute_target_output( self, logits: torch.Tensor, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Scalar target output for backpropagation.""" - target_f = target.to(logits.device).float() - mode = self._prediction_mode() - - if mode == "binary": - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(-1) - target_f = target_f.expand_as(logits) - signs = 2.0 * target_f - 1.0 - return (signs * logits).sum() - else: - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(0) - target_f = target_f.expand_as(logits) - return (target_f * logits).sum() + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1).sum() # ------------------------------------------------------------------ # Baseline generation diff --git a/pyhealth/interpret/methods/integrated_gradients.py b/pyhealth/interpret/methods/integrated_gradients.py index a529a6f3f..249f5a5e9 100644 --- a/pyhealth/interpret/methods/integrated_gradients.py +++ b/pyhealth/interpret/methods/integrated_gradients.py @@ -217,9 +217,11 @@ def attribute( the integral. If None, uses self.steps (set during initialization). More steps lead to better approximation but slower computation. - target_class_idx: Target class index for attribution - computation. If None, uses the predicted class (argmax of - model output). + target_class_idx: Target class index for attribution. + For binary classification (single logit output), this is + a no-op because there is only one output. 
For multi-class + or multi-label, specifies which class to explain. If None, + uses the argmax of model output. **kwargs: Input data dictionary from a dataloader batch containing: - Feature keys (e.g., 'conditions', 'procedures'): @@ -324,35 +326,7 @@ def attribute( with torch.no_grad(): base_logits = self.model.forward(**inputs)["logit"] - mode = self._prediction_mode() - if mode == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif mode == "multiclass": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = torch.argmax(base_logits, dim=-1) - target = F.one_hot( - target, num_classes=base_logits.shape[-1] - ).float() - elif mode == "multilabel": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = (torch.sigmoid(base_logits) > 0.5).float() - else: - raise ValueError( - "Unsupported prediction mode for Integrated Gradients attribution." - ) + target_indices = self._resolve_target_indices(base_logits, target_class_idx) # Generate baselines if baseline is None: @@ -405,7 +379,7 @@ def attribute( xs=values, bs=baselines, steps=steps, - target=target, + target_indices=target_indices, ) return self._map_to_input_shapes(attributions, shapes) @@ -419,7 +393,7 @@ def _integrated_gradients( xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], steps: int, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> Dict[str, torch.Tensor]: """Compute integrated gradients via Riemann sum approximation. @@ -438,8 +412,7 @@ def _integrated_gradients( xs: Input values (embedded if use_embeddings=True). bs: Baseline values (embedded if use_embeddings=True). steps: Number of interpolation steps. 
- target: Target tensor for computing the scalar output to - differentiate (one-hot for multiclass, class idx for binary). + target_indices: [batch] tensor of target class indices. Returns: Dictionary mapping feature keys to attribution tensors. @@ -513,7 +486,7 @@ def _integrated_gradients( logits = output["logit"] # Compute target output and backward pass - target_output = self._compute_target_output(logits, target) + target_output = self._compute_target_output(logits, target_indices) self.model.zero_grad() target_output.backward(retain_graph=True) @@ -550,41 +523,23 @@ def _integrated_gradients( def _compute_target_output( self, logits: torch.Tensor, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Compute scalar target output for backpropagation. - Creates a differentiable scalar from the model logits that, - when differentiated, gives the gradient of the target class - logit w.r.t. the input. + Selects the target-class logit for each sample and sums over + the batch to produce a single differentiable scalar. Args: - logits: Model output logits, shape [batch, num_classes] or - [batch, 1]. - target: Target tensor. For binary: [batch] or [1] with 0/1 - class indices. For multiclass/multilabel: [batch, num_classes] - one-hot or multi-hot tensor. + logits: Model output logits, shape [batch, num_classes]. + target_indices: [batch] tensor of target class indices. Returns: Scalar tensor for backpropagation. 
""" - target_f = target.to(logits.device).float() - mode = self._prediction_mode() - - if mode == "binary": - # target shape: [1] or [batch, 1] with 0/1 values - # Convert to signs: 0 -> -1, 1 -> 1 - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(-1) - target_f = target_f.expand_as(logits) - signs = 2.0 * target_f - 1.0 - return (signs * logits).sum() - else: - # multiclass or multilabel: target is one-hot/multi-hot - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(0) - target_f = target_f.expand_as(logits) - return (target_f * logits).sum() + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1).sum() # ------------------------------------------------------------------ # Baseline generation diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 5176bfeaf..4e407fccf 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -250,25 +250,7 @@ def attribute( # Extract and prepare inputs base_logits = self.model.forward(**inputs)["logit"] - # Enforce target class selection for multi-class models to avoid class flipping - if self._prediction_mode() == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif self._prediction_mode() == "multiclass": - if target_class_idx is not None: - target = torch.nn.functional.one_hot(torch.tensor(target_class_idx, device=device), num_classes=base_logits.shape[-1]) - else: - target = torch.argmax(base_logits, dim=-1) - target = torch.nn.functional.one_hot(target, num_classes=base_logits.shape[-1]) - elif self._prediction_mode() == "multilabel": - if target_class_idx is not None: - target = torch.nn.functional.one_hot(torch.tensor(target_class_idx, device=device), num_classes=base_logits.shape[-1]) - else: - target = torch.sigmoid(base_logits) > 0.5 - else: - raise ValueError("Unsupported 
prediction mode for LIME attribution.") + target_indices = self._resolve_target_indices(base_logits, target_class_idx) if baseline is None: baselines = self._generate_baseline(values, use_embeddings=self.use_embeddings) @@ -309,7 +291,7 @@ def attribute( xs=values, bs=baselines, n_features=n_features, - target=target, + target_indices=target_indices, ) return self._map_to_input_shapes(out, shapes) @@ -323,7 +305,7 @@ def _compute_lime( xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], n_features: dict[str, int], - target: torch.Tensor, + target_indices: torch.Tensor, ) -> Dict[str, torch.Tensor]: """Compute LIME coefficients using interpretable linear model. @@ -376,7 +358,7 @@ def _compute_lime( pred = self._evaluate_sample( inputs, perturb, - target, + target_indices, ) # Create perturbed sample for each batch item @@ -475,15 +457,19 @@ def _evaluate_sample( self, inputs: dict[str, tuple[torch.Tensor, ...]], perturb: dict[str, torch.Tensor], - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Evaluate model prediction for a perturbed sample. + Returns the model's prediction for the target class, so the + weighted linear regression approximates the model's actual + output (not a distance to a label). + Args: inputs: Original input tuples (used for non-value fields like time/mask). perturb: Perturbed sample tensors. Token features are already embedded; continuous features are still in raw space. - target: Target class tensor. + target_indices: [batch] tensor of target class indices. Returns: Model prediction for the perturbed sample, shape (batch_size, ). @@ -523,8 +509,10 @@ def _evaluate_sample( # model's regular forward pass handle embedding internally. 
logits = self.model.forward(**inputs)["logit"] - # Reduce to [batch_size, ] by taking absolute difference from target class logit - return (target - logits).abs().mean(dim=tuple(range(1, logits.ndim))) + # Extract the target class prediction (logits, not label distances) + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1) def _compute_similarity( self, diff --git a/pyhealth/interpret/methods/shap.py b/pyhealth/interpret/methods/shap.py index 46d40a977..df52ea732 100644 --- a/pyhealth/interpret/methods/shap.py +++ b/pyhealth/interpret/methods/shap.py @@ -220,33 +220,7 @@ def attribute( # Extract and prepare inputs base_logits = self.model.forward(**inputs)["logit"] - # Enforce target class selection for multi-class models to avoid class flipping - if self._prediction_mode() == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif self._prediction_mode() == "multiclass": - if target_class_idx is not None: - target = torch.nn.functional.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ) - else: - target = torch.argmax(base_logits, dim=-1) - target = torch.nn.functional.one_hot( - target, num_classes=base_logits.shape[-1] - ) - elif self._prediction_mode() == "multilabel": - if target_class_idx is not None: - target = torch.nn.functional.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ) - else: - target = torch.sigmoid(base_logits) > 0.5 - else: - raise ValueError("Unsupported prediction mode for SHAP attribution.") + target_indices = self._resolve_target_indices(base_logits, target_class_idx) if baseline is None: baselines = self._generate_background_samples( @@ -295,7 +269,7 @@ def attribute( xs=values, bs=baselines, n_features=n_features, - target=target, + target_indices=target_indices, ) return self._map_to_input_shapes(out, shapes) 
@@ -309,7 +283,7 @@ def _compute_kernel_shap( xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], n_features: dict[str, int], - target: torch.Tensor, + target_indices: torch.Tensor, ) -> Dict[str, torch.Tensor]: """Compute SHAP values using the Kernel SHAP approximation method. @@ -325,7 +299,7 @@ def _compute_kernel_shap( xs: Dictionary of input values (or embeddings). bs: Dictionary of baseline values (or embeddings). n_features: Dictionary mapping feature keys to feature counts. - target: Target tensor for prediction comparison. + target_indices: [batch] tensor of target class indices. Returns: Dictionary mapping feature keys to SHAP value tensors. @@ -353,7 +327,7 @@ def _compute_kernel_shap( coalition, keys, n_features, batch_size ) perturb = self._create_perturbed_sample(xs, bs, gates) - pred = self._evaluate_sample(inputs, perturb, target) + pred = self._evaluate_sample(inputs, perturb, target_indices) coalition_vectors.append(coalition.float()) coalition_preds.append(pred.detach()) @@ -374,7 +348,7 @@ def _compute_kernel_shap( coalition, keys, n_features, batch_size ) perturb = self._create_perturbed_sample(xs, bs, gates) - pred = self._evaluate_sample(inputs, perturb, target) + pred = self._evaluate_sample(inputs, perturb, target_indices) coalition_vectors.append(coalition.float()) coalition_preds.append(pred.detach()) @@ -480,7 +454,7 @@ def _evaluate_sample( self, inputs: dict[str, tuple[torch.Tensor, ...]], perturb: dict[str, torch.Tensor], - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Evaluate model prediction for a perturbed sample. @@ -492,9 +466,7 @@ def _evaluate_sample( Args: inputs: Original input tuples from the dataloader. perturb: Dictionary of perturbed value tensors. - target: Target tensor used to select which class prediction to - return. For binary this is a 0/1 scalar or (batch,1) tensor; - for multiclass/multilabel it is a one-hot vector. + target_indices: [batch] tensor of target class indices. 
Returns: Target-class prediction scalar per batch item, shape (batch_size,). @@ -537,56 +509,9 @@ def _evaluate_sample( # model's regular forward pass handle embedding internally. logits = self.model.forward(**inputs)["logit"] - return self._extract_target_prediction(logits, target) - - def _extract_target_prediction( - self, - logits: torch.Tensor, - target: torch.Tensor, - ) -> torch.Tensor: - """Extract the model's prediction for the target class. - - Kernel SHAP decomposes f(x) ≈ φ₀ + Σ φᵢ zᵢ via weighted least squares. - Using **raw logits** (unbounded) rather than probabilities (bounded - [0, 1]) is critical: sigmoid compression squashes coalition differences - in the saturated regions, producing uniformly small SHAP values and - degraded feature rankings. - - Args: - logits: Raw model logits, shape (batch_size, n_classes) or - (batch_size, 1). - target: Target indicator. Binary: scalar/tensor with 0 or 1. - Multiclass: one-hot tensor. Multilabel: multi-hot tensor. - - Returns: - Scalar prediction per batch item, shape (batch_size,). - """ - mode = self._prediction_mode() - - if mode == "binary": - # Use raw logit — not sigmoid probability — to preserve the - # dynamic range that Kernel SHAP's linear decomposition needs. 
- logit = logits.squeeze(-1) # (batch,) - t = target.float() - if t.dim() > 1: - t = t.squeeze(-1) - # target=1 → logit (higher logit ⇒ more positive class) - # target=0 → −logit (higher value ⇒ more negative class) - return t * logit + (1 - t) * (-logit) - - elif mode == "multiclass": - # target is one-hot; dot-product extracts the target-class logit - return (target.float() * logits).sum(dim=-1) # (batch,) - - elif mode == "multilabel": - # target is multi-hot; average logits over active labels - t = target.float() - n_active = t.sum(dim=-1).clamp(min=1) # avoid div-by-zero - return (t * logits).sum(dim=-1) / n_active # (batch,) - - else: - # regression or unknown — just return the logit - return logits.squeeze(-1) + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1) # ------------------------------------------------------------------ # Weighted least squares solver diff --git a/pyhealth/metrics/interpretability/__init__.py b/pyhealth/metrics/interpretability/__init__.py index d0de26057..13bb0831a 100644 --- a/pyhealth/metrics/interpretability/__init__.py +++ b/pyhealth/metrics/interpretability/__init__.py @@ -4,7 +4,12 @@ from .comprehensiveness import ComprehensivenessMetric from .evaluator import Evaluator, evaluate_attribution from .sufficiency import SufficiencyMetric -from .utils import create_validity_mask, get_model_predictions +from .utils import ( + SampleClass, + SampleFilterFn, + get_model_predictions, + threshold_sample_filter, +) __all__ = [ "ComprehensivenessMetric", @@ -12,7 +17,10 @@ "RemovalBasedMetric", "Evaluator", "evaluate_attribution", + # Sample classification + "SampleClass", + "SampleFilterFn", + "threshold_sample_filter", # Utility functions "get_model_predictions", - "create_validity_mask", ] diff --git a/pyhealth/metrics/interpretability/base.py b/pyhealth/metrics/interpretability/base.py index d3dc120f9..ef388402b 100644 --- a/pyhealth/metrics/interpretability/base.py +++ b/pyhealth/metrics/interpretability/base.py 
@@ -11,7 +11,11 @@ from pyhealth.models import BaseModel -from .utils import create_validity_mask, get_model_predictions +from .utils import ( + SampleClass, + SampleFilterFn, + get_model_predictions, +) class RemovalBasedMetric(ABC): @@ -30,8 +34,15 @@ class RemovalBasedMetric(ABC): - 'mean': Set ablated features to feature mean across batch - 'noise': Add Gaussian noise to ablated features Default: 'zero'. - positive_threshold: Threshold for positive class in binary - classification. Default: 0.5. + sample_filter: A callable that classifies each sample for evaluation. + Signature: (class_probs, classifier_type) -> sample_classes + where class_probs has shape (batch_size,) and contains the + probability for the predicted class (sigmoid/softmax output + with target class already applied), and sample_classes is a + tensor of SampleClass values. + - SampleClass.POSITIVE: evaluate with attributions as-is + - SampleClass.NEGATIVE: evaluate with negated attributions + - SampleClass.IGNORE: exclude from evaluation """ def __init__( @@ -39,12 +50,13 @@ def __init__( model: BaseModel, percentages: List[float] = [1, 5, 10, 20, 50], ablation_strategy: str = "zero", - positive_threshold: float = 0.5, + *, + sample_filter: SampleFilterFn, ): self.model = model self.percentages = percentages self.ablation_strategy = ablation_strategy - self._positive_threshold = positive_threshold + self._sample_filter = sample_filter self.model.eval() # Detect classifier type from model @@ -111,7 +123,7 @@ def _detect_classifier_type(self): self.num_classes = 2 print("[RemovalBasedMetric] Detected BINARY classifier") print(" - Output shape: [batch, 1] with P(class=1)") - print(" - Only evaluates positive predictions (>=threshold)") + print(" - Evaluates both positive and negative predictions") elif mode == "multiclass": self.classifier_type = "multiclass" # Get num_classes from processor @@ -365,42 +377,38 @@ def compute( samples have value 0. 
Note: - For binary classifiers, the valid_mask indicates samples with - P(class=1) >= threshold (default 0.5). Use this mask to filter - scores during averaging or analysis. + For binary classifiers, all samples are evaluated + (both positive and negative predictions). For class 0 + predictions, attributions are negated internally so that + feature importance is measured relative to the predicted + class. """ # Get original predictions (returns 3 values) - original_probs, pred_classes, original_class_probs = get_model_predictions( + y_probs, target_class_idx, sample_class = get_model_predictions( model=self.model, inputs=inputs, classifier_type=self.classifier_type, - pred_classes=predicted_class, - positive_threshold=self._positive_threshold, + sample_filter=self._sample_filter, ) - - if predicted_class is not None: - pred_classes = predicted_class - - batch_size = original_probs.shape[0] - - # Create validity mask using helper - valid_mask = create_validity_mask( - original_probs, - self.classifier_type, - self._positive_threshold, - ) - - # For binary: determine which samples to evaluate - if self.classifier_type == "binary": - positive_mask = pred_classes == 1 - num_positive = positive_mask.sum().item() - num_negative = (~positive_mask).sum().item() - else: - positive_mask = torch.ones( - batch_size, dtype=torch.bool, device=original_probs.device - ) - num_positive = batch_size - num_negative = 0 + + batch_size = y_probs.shape[0] + + # Validity mask: IGNORE samples excluded + val_mask = sample_class != SampleClass.IGNORE + + # For NEGATIVE samples, negate attributions so that + # "top features" become those most important for the predicted + # class (features with low class-1 attribution support class 0). 
+ neg_mask = sample_class == SampleClass.NEGATIVE + if neg_mask.any(): + attributions = { + key: torch.where( + neg_mask.view(-1, *([1] * (attr.dim() - 1))), + -attr, + attr, + ) + for key, attr in attributions.items() + } # Debug output (if requested and returning per percentage) if debug and return_per_percentage: @@ -411,37 +419,21 @@ def compute( print(f"Classifier type: {self.classifier_type}") if self.classifier_type == "binary": - print(f"Positive class samples: {num_positive}") - print(f"Negative class samples: {num_negative}") - print("NOTE: Only computing metrics for POSITIVE class") - - print(f"Original probs shape: {original_probs.shape}") - print(f"Predicted classes: {pred_classes.tolist()}") - - if self.classifier_type == "binary": - print("\nOriginal probabilities P(class=1):") - for i, prob in enumerate(original_probs): - status = "EVAL" if positive_mask[i] else "SKIP" - print(f" Sample {i} [{status}]: {prob.item():.6f}") - else: - print("\nOriginal probabilities (all classes):") - for i, probs in enumerate(original_probs): - print(f" Sample {i}: {probs.tolist()}") + print(f"Positive class samples: {(sample_class == SampleClass.POSITIVE).sum().item()}") + print(f"Negative class samples: {(sample_class == SampleClass.NEGATIVE).sum().item()}") + print("NOTE: Evaluating BOTH positive and negative predictions") print("\nOriginal probs for predicted class:") - for i, prob in enumerate(original_class_probs): - if self.classifier_type == "binary": - status = "EVAL" if positive_mask[i] else "SKIP" - print(f" Sample {i} [{status}]: {prob.item():.6f}") - else: - print(f" Sample {i}: {prob.item():.6f}") + for i, prob in enumerate(y_probs): + cls = target_class_idx[i].item() + print(f" Sample {i} [class={cls}]: {prob.item():.6f}") # Store results per percentage if return_per_percentage: results = {} else: # Accumulator for averaging - metric_scores = torch.zeros(batch_size, device=original_probs.device) + metric_scores = torch.zeros(batch_size, 
device=y_probs.device) # Compute metrics across all percentages for percentage in self.percentages: @@ -452,18 +444,24 @@ def compute( ablated_inputs = self._create_ablated_inputs(inputs, masks) # Get predictions on ablated inputs - ablated_probs, _, ablated_class_probs = get_model_predictions( + ablated_probs, _, _ = get_model_predictions( model=self.model, inputs=ablated_inputs, - pred_classes=pred_classes, # Use same predicted classes from original to avoid shifts + target_class_idx=target_class_idx, # Use same predicted classes from original to avoid shifts + sample_class=sample_class, # Use same sample classes to ensure consistency classifier_type=self.classifier_type, - positive_threshold=self._positive_threshold, ) # Compute probability drop - prob_drop = torch.zeros(batch_size, device=original_probs.device) - prob_drop[positive_mask] = ( - original_class_probs[positive_mask] - ablated_class_probs[positive_mask] + original_class_probs = y_probs + original_class_probs[neg_mask] = -original_class_probs[neg_mask] + + ablated_class_probs = ablated_probs + ablated_class_probs[neg_mask] = -ablated_class_probs[neg_mask] + + prob_drop = torch.zeros(batch_size, device=y_probs.device) + prob_drop[val_mask] = ( + original_class_probs[val_mask] - ablated_class_probs[val_mask] ) # Debug output for this percentage @@ -476,8 +474,8 @@ def compute( if self.classifier_type == "binary": print("\nAblated probabilities P(class=1):") for i, prob in enumerate(ablated_probs): - status = "EVAL" if positive_mask[i] else "SKIP" - print(f" Sample {i} [{status}]: {prob.item():.6f}") + cls = target_class_idx[i].item() + print(f" Sample {i} [class={cls}]: {prob.item():.6f}") else: print("\nAblated probabilities (all classes):") for i, probs in enumerate(ablated_probs): @@ -485,18 +483,16 @@ def compute( print("\nProbability drops (original - ablated):") for i, drop in enumerate(prob_drop): - if drop == 0 and not positive_mask[i]: - print(f" Sample {i} [SKIP]: " f"0.000000 (negative 
class)") - else: - orig = original_class_probs[i].item() - abl = ablated_class_probs[i].item() - print( - f" Sample {i} [EVAL]: {drop.item():.6f} " - f"({orig:.6f} - {abl:.6f})" - ) + orig = original_class_probs[i].item() + abl = ablated_class_probs[i].item() + cls = target_class_idx[i].item() + print( + f" Sample {i} [class={cls}]: {drop.item():.6f} " + f"({orig:.6f} - {abl:.6f})" + ) # Check for unexpected negative values - evaluated_drops = prob_drop[positive_mask] + evaluated_drops = prob_drop[val_mask] neg_mask = evaluated_drops < 0 if neg_mask.any(): neg_count = neg_mask.sum().item() @@ -511,15 +507,15 @@ def compute( print(" - Attribution quality may be poor") if return_per_percentage: - results[percentage] = prob_drop + results[percentage] = prob_drop # type: ignore else: # Accumulate for averaging - metric_scores = metric_scores + prob_drop + metric_scores = metric_scores + prob_drop # type: ignore # Return appropriate format if return_per_percentage: - return results + return results # type: ignore else: # Average across percentages - metric_scores = metric_scores / len(self.percentages) - return metric_scores, valid_mask + metric_scores = metric_scores / len(self.percentages) # type: ignore + return metric_scores, val_mask diff --git a/pyhealth/metrics/interpretability/evaluator.py b/pyhealth/metrics/interpretability/evaluator.py index fe6f99f61..1344358eb 100644 --- a/pyhealth/metrics/interpretability/evaluator.py +++ b/pyhealth/metrics/interpretability/evaluator.py @@ -4,7 +4,8 @@ using removal-based metrics like Comprehensiveness and Sufficiency. 
""" -from typing import Dict, List +from typing import Dict, List, Optional +import warnings import torch @@ -12,6 +13,7 @@ from .comprehensiveness import ComprehensivenessMetric from .sufficiency import SufficiencyMetric +from .utils import SampleClass, SampleFilterFn, threshold_sample_filter class Evaluator: @@ -30,17 +32,46 @@ class Evaluator: - 'mean': Set ablated features to feature mean across batch - 'noise': Add Gaussian noise to ablated features Default: 'zero'. - positive_threshold: Threshold for positive class in binary - classification. Samples with P(class=1) >= threshold are - considered valid for evaluation. Default: 0.5. + sample_filter: A callable that classifies each sample for evaluation. + Signature: (class_probs, classifier_type) -> sample_classes + where class_probs has shape (batch_size,) and contains the + class probability used for filtering. For binary single-logit + models, this is ``P(class=1)``. For multiclass/multilabel + models, this is the gathered target-class probability. + ``sample_classes`` is a tensor of SampleClass values: + - SampleClass.POSITIVE: evaluate with attributions as-is + - SampleClass.NEGATIVE: evaluate with negated attributions + - SampleClass.IGNORE: exclude from evaluation + If None, uses default_sample_filter. + positive_threshold: .. deprecated:: + This parameter is deprecated and will be removed in a future + release. Use ``sample_filter`` with + :func:`threshold_sample_filter` instead. + Threshold for positive class in binary classification. + Default: None. Examples: >>> from pyhealth.models import StageNet >>> from pyhealth.metrics.interpretability import Evaluator + >>> from pyhealth.metrics.interpretability.utils import ( + ... SampleClass, + ... threshold_sample_filter, + ... 
) >>> - >>> # Initialize evaluator + >>> # Initialize evaluator with default filter >>> evaluator = Evaluator(model) >>> + >>> # Initialize with custom filter that ignores low-confidence + >>> def confident_filter(class_probs, classifier_type): + ... batch_size = class_probs.shape[0] + ... result = torch.full( + ... (batch_size,), SampleClass.POSITIVE, + ... dtype=torch.long, device=class_probs.device, + ... ) + ... result[class_probs < 0.6] = SampleClass.IGNORE + ... return result + >>> evaluator = Evaluator(model, sample_filter=confident_filter) + >>> >>> # Evaluate on a single batch >>> inputs = {'conditions': torch.randn(32, 50)} >>> attributions = {'conditions': torch.randn(32, 50)} @@ -61,24 +92,53 @@ def __init__( model: BaseModel, percentages: List[float] = [1, 5, 10, 20, 50], ablation_strategy: str = "zero", - positive_threshold: float = 0.5, + sample_filter: Optional[SampleFilterFn] = None, + positive_threshold: Optional[float] = None, ): self.model = model self.percentages = percentages self.ablation_strategy = ablation_strategy self.positive_threshold = positive_threshold + + # Resolve the effective sample filter: + # 1. explicit sample_filter wins + # 2. positive_threshold → threshold_sample_filter(positive_threshold) + # 3. fallback → default (threshold_sample_filter(0.5)) + if sample_filter is not None: + if positive_threshold is not None: + warnings.warn( + "Both sample_filter and positive_threshold were given. " + "sample_filter takes precedence; positive_threshold is " + "ignored.", + UserWarning, + stacklevel=2, + ) + resolved_filter = sample_filter + elif positive_threshold is not None: + warnings.warn( + "positive_threshold is deprecated and will be removed in a " + "future release. 
Use sample_filter with " + "threshold_sample_filter() instead.", + DeprecationWarning, + stacklevel=2, + ) + resolved_filter = threshold_sample_filter(positive_threshold) + else: + resolved_filter = threshold_sample_filter(0.5) + + self.sample_filter = resolved_filter self.metrics = { "comprehensiveness": ComprehensivenessMetric( model, percentages=percentages, ablation_strategy=ablation_strategy, - positive_threshold=positive_threshold, + sample_filter=resolved_filter, ), "sufficiency": SufficiencyMetric( model, percentages=percentages, ablation_strategy=ablation_strategy, - positive_threshold=positive_threshold, + sample_filter=resolved_filter, ), } @@ -113,8 +173,9 @@ def evaluate( Example: {'comprehensiveness': {10: tensor(...), 20: ...}} Note: - For binary classifiers, valid_mask indicates samples with - P(class=1) >= threshold. Use: scores[valid_mask].mean() + For binary classifiers, all samples are evaluated + (both positive and negative predictions). + Use: scores[valid_mask].mean() Examples: >>> # Default: averaged scores @@ -166,16 +227,18 @@ def evaluate_attribution( Returns: Dictionary mapping metric names to their average scores - across the entire dataset. For binary classifiers, only - positive class (predicted class=1) samples are included - in the average. + across the entire dataset. Samples marked ``IGNORE`` by the + configured ``sample_filter`` are excluded from the average. Example: {'comprehensiveness': 0.345, 'sufficiency': 0.123} Note: - For binary classifiers, negative class (predicted class=0) - samples are excluded from the average, as ablation metrics - are not meaningful for the default/null class. + For binary classifiers, both positive and negative samples can + be evaluated. Negative samples are handled by negating the + attribution scores before top-feature selection, which makes + the probability drop equivalent to the drop in confidence for + class 0. 
Use ``sample_filter`` to include or exclude whichever + subsets you want in the dataset average. Examples: >>> from pyhealth.interpret.methods import IntegratedGradients @@ -245,13 +308,15 @@ def evaluate_attribution( ) # Accumulate statistics incrementally (no tensor storage) + first_metric = metrics[0] + batch_size = len(batch_results[first_metric][0]) + total_samples += batch_size + for metric_name in metrics: scores, valid_mask = batch_results[metric_name] # Track statistics efficiently - batch_size = len(scores) num_valid = valid_mask.sum().item() - total_samples += batch_size total_valid[metric_name] += num_valid # Update running sum (valid scores only) @@ -325,20 +390,19 @@ def evaluate_attribution( print(" * Important features not correctly identified") print(" * Consider checking attribution method") - valid_ratio = sum(total_valid.values()) / (len(metrics) * total_samples) - if valid_ratio < 0.1: + valid_ratio = sum(total_valid.values()) / (len(metrics) * total_samples) if total_samples > 0 else 0 + if valid_ratio < 0.1 and total_samples > 0: print(f"\n⚠ WARNING: Only {valid_ratio*100:.1f}% valid samples") print(" - Most predictions are negative class") print(" - Consider:") print(" * Checking model predictions distribution") - print(" * Adjusting positive_threshold parameter") + print(" * Adjusting sample_filter to include more samples") print(" * Using balanced test set") print(f"{'='*70}\n") return results - # Functional API (wraps Evaluator for convenience) def evaluate_attribution( model: BaseModel, @@ -347,7 +411,8 @@ def evaluate_attribution( metrics: List[str] = ["comprehensiveness", "sufficiency"], percentages: List[float] = [1, 5, 10, 20, 50], ablation_strategy: str = "zero", - positive_threshold: float = 0.5, + sample_filter: Optional[SampleFilterFn] = None, + positive_threshold: Optional[float] = None, ) -> Dict[str, float]: """Evaluate an attribution method across a dataset (functional API). 
@@ -371,27 +436,37 @@ def evaluate_attribution( - 'mean': Set ablated features to feature mean across batch - 'noise': Add Gaussian noise to ablated features Default: 'zero'. - positive_threshold: Threshold for positive class in binary - classification. Samples with P(class=1) >= threshold are - considered valid for evaluation. Default: 0.5. + sample_filter: A callable that classifies each sample for + evaluation. Signature: + (class_probs, classifier_type) -> sample_classes + where class_probs has shape (batch_size,) and contains the + probability for the predicted class (sigmoid/softmax output + with target class already applied), and sample_classes is a + tensor of SampleClass values: + - SampleClass.POSITIVE: evaluate with attributions as-is + - SampleClass.NEGATIVE: evaluate with negated attributions + - SampleClass.IGNORE: exclude from evaluation + If None, uses default_sample_filter. + positive_threshold: .. deprecated:: + This parameter is deprecated and will be removed in a future + release. Use ``sample_filter`` with + :func:`threshold_sample_filter` instead. + Threshold for positive class in binary classification. + Default: None. Returns: Dictionary mapping metric names to their average scores across the entire dataset. Averaging uses mask-based filtering - to include only valid samples (positive predictions for binary). + to exclude IGNORE samples. Example: {'comprehensiveness': 0.345, 'sufficiency': 0.123} - Note: - For binary classifiers, only samples with P(class=1) >= threshold - are included in the average, as ablation metrics are not - meaningful for negative predictions. - Examples: >>> from pyhealth.interpret.methods import IntegratedGradients >>> from pyhealth.metrics.interpretability import ( ... evaluate_attribution ... ) + >>> from pyhealth.metrics.interpretability.utils import SampleClass >>> >>> # Simple one-off evaluation >>> ig = IntegratedGradients(model, use_embeddings=True) @@ -402,10 +477,18 @@ def evaluate_attribution( ... 
) >>> print(f"Comprehensiveness: {results['comprehensiveness']:.4f}") >>> - >>> # Custom threshold for binary classification + >>> # Custom filter to ignore uncertain predictions + >>> def ignore_uncertain(class_probs, classifier_type): + ... batch_size = class_probs.shape[0] + ... result = torch.full( + ... (batch_size,), SampleClass.POSITIVE, + ... dtype=torch.long, device=class_probs.device, + ... ) + ... result[class_probs < 0.7] = SampleClass.IGNORE + ... return result >>> results = evaluate_attribution( ... model, test_loader, ig, - ... positive_threshold=0.7 # Only evaluate high-confidence + ... sample_filter=ignore_uncertain, ... ) >>> >>> # For comparing multiple methods efficiently, use Evaluator: @@ -420,6 +503,7 @@ def evaluate_attribution( model, percentages=percentages, ablation_strategy=ablation_strategy, + sample_filter=sample_filter, positive_threshold=positive_threshold, ) return evaluator.evaluate_attribution(dataloader, method, metrics=metrics) diff --git a/pyhealth/metrics/interpretability/utils.py b/pyhealth/metrics/interpretability/utils.py index 5206bbf1d..3f10814cb 100644 --- a/pyhealth/metrics/interpretability/utils.py +++ b/pyhealth/metrics/interpretability/utils.py @@ -4,7 +4,8 @@ metrics to avoid code duplication and improve maintainability. """ -from typing import Dict, Optional, Tuple +from enum import IntEnum +from typing import Callable, Dict, Optional, Tuple import torch import torch.nn.functional as F @@ -12,12 +13,83 @@ from pyhealth.models import BaseModel +class SampleClass(IntEnum): + """Classification of how a sample should be treated during evaluation. + + Attributes: + POSITIVE: Evaluate sample with attributions as-is. + Used for predicted positive class in binary, or all + samples in multiclass/multilabel. + NEGATIVE: Evaluate sample with negated attributions. + Used for predicted negative class in binary classification, + where feature importance is measured relative to the + predicted class (class 0). 
+ IGNORE: Exclude sample from evaluation entirely. + Useful for filtering out low-confidence predictions or + samples that should not contribute to the metric. + """ + + POSITIVE = 1 + NEGATIVE = -1 + IGNORE = 0 + + +# Type alias for sample filter functions. +# Signature: (y_probs, classifier_type) -> sample_classes +# y_probs has shape (batch_size,). For binary single-logit models this is +# P(class=1); for multiclass/multilabel models this is the gathered +# target-class probability. +SampleFilterFn = Callable[[torch.Tensor, str], torch.Tensor] + + +def threshold_sample_filter(threshold: float = 0.5) -> SampleFilterFn: + """Create a filter based on a probability threshold. + + For binary and multilabel classifiers, samples whose predicted-class + probability is at or above ``threshold`` are marked POSITIVE; all + others are marked IGNORE. + + For multiclass classifiers, all samples are marked POSITIVE + (the argmax class always has a well-defined probability). + + Args: + threshold: Minimum predicted-class probability to include + the sample. Default: 0.5. + + Returns: + A sample filter function. 
+ + Examples: + >>> # Create a filter that ignores uncertain predictions + >>> my_filter = threshold_sample_filter(0.7) + >>> evaluator = Evaluator(model, sample_filter=my_filter) + """ + + def filter_fn( + y_probs: torch.Tensor, + classifier_type: str, + ) -> torch.Tensor: + batch_size = y_probs.shape[0] + result = torch.full( + (batch_size,), + SampleClass.POSITIVE, + dtype=torch.long, + device=y_probs.device, + ) + if classifier_type in ("binary", "multilabel"): + result[y_probs < threshold] = SampleClass.IGNORE + return result + + return filter_fn + + def get_model_predictions( model: BaseModel, inputs: Dict[str, torch.Tensor], classifier_type: str, - pred_classes: Optional[torch.Tensor] = None, - positive_threshold: float = 0.5, + sample_filter: Optional[SampleFilterFn] = None, + sample_class: Optional[torch.Tensor] = None, + target_class_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get model predictions, probabilities, and class-specific probabilities. @@ -25,18 +97,22 @@ def get_model_predictions( model: PyHealth BaseModel that returns dict with 'y_prob' or 'logit' inputs: Model inputs dict classifier_type: One of 'binary', 'multiclass', 'multilabel', 'unknown' - pred_classes: (Optional) Pre-computed predicted classes, this would ensure ablated runs + target_class_idx: (Optional) Pre-computed target class indices, this would ensure ablated runs are consistent with original predictions. If None, will compute from model outputs. - positive_threshold: Threshold for binary classification (default: 0.5) + sample_filter: A callable that classifies each sample for evaluation. + Signature: (class_probs, classifier_type) -> sample_classes + where class_probs has shape (batch_size,). For binary + single-logit models this is ``P(class=1)``; otherwise it is + the gathered target-class probability. ``sample_classes`` is + a tensor of SampleClass values. 
Returns: - Tuple of (y_prob, pred_classes, class_probs): + Tuple of (y_prob, target_class_idx, sample_classes): - y_prob: All class probabilities - Binary: shape (batch_size, 1), values are P(class=1) - Multiclass: shape (batch_size, num_classes) - - pred_classes: Predicted class indices, shape (batch_size,) - - class_probs: Probability for each sample's predicted class, - shape (batch_size,) + - target_class_idx: Target class indices, shape (batch_size,) + - sample_classes: SampleClass values for each sample, shape (batch_size,) """ with torch.no_grad(): outputs = model(**inputs) @@ -46,7 +122,7 @@ def get_model_predictions( y_prob = outputs["y_prob"] elif "logit" in outputs: logits = outputs["logit"] - if classifier_type == "binary": + if classifier_type in ["binary", "multilabel"]: y_prob = torch.sigmoid(logits) else: y_prob = F.softmax(logits, dim=-1) @@ -57,45 +133,22 @@ def get_model_predictions( if y_prob.dim() == 1: y_prob = y_prob.unsqueeze(-1) - # Get predicted classes based on classifier type - if classifier_type == "binary": - # For binary: class 1 if P(class=1) >= threshold, else 0 - pred_classes = (y_prob.squeeze(-1) >= positive_threshold).long() if pred_classes is None else pred_classes - # For binary, class_probs is P(class=1) - class_probs = y_prob.squeeze(-1) - else: - # For multiclass/multilabel: argmax - pred_classes = torch.argmax(y_prob, dim=-1) if pred_classes is None else pred_classes - # Gather probabilities for predicted classes - class_probs = y_prob.gather(1, pred_classes.unsqueeze(1)).squeeze(1) - assert pred_classes is not None, "pred_classes should have been set either by input or computation." 
- - return y_prob, pred_classes, class_probs - + if target_class_idx is None: + target_class_idx = torch.argmax(y_prob, dim=-1) + + y_prob = y_prob.gather( + dim=-1, + index=target_class_idx.unsqueeze(-1), + ).squeeze(-1) + + # Apply sample filter + if sample_class is None: + if sample_filter is None: + raise ValueError("sample_filter must be provided if sample_class is None") + sample_class = sample_filter(y_prob, classifier_type) + + y_prob[sample_class == SampleClass.IGNORE] = 0.0 # Set ignored samples' probs to 0 + target_class_idx[sample_class == SampleClass.IGNORE] = 0 # Mark ignored samples with invalid class index + + return y_prob, target_class_idx, sample_class -def create_validity_mask( - y_prob: torch.Tensor, - classifier_type: str, - positive_threshold: float = 0.5, -) -> torch.Tensor: - """Create a mask indicating which samples are valid for metric computation. - - For binary classifiers, only positive predictions (P(class=1) >= threshold) - are considered valid. For multiclass/multilabel, all samples are valid. 
- - Args: - y_prob: Model probability outputs - classifier_type: One of 'binary', 'multiclass', 'multilabel' - positive_threshold: Threshold for binary classification (default: 0.5) - - Returns: - Boolean tensor of shape (batch_size,) where True indicates valid samples - """ - batch_size = y_prob.shape[0] - - if classifier_type == "binary": - # For binary: valid = P(class=1) >= threshold - return y_prob.squeeze(-1) >= positive_threshold - else: - # For multiclass/multilabel: all samples are valid - return torch.ones(batch_size, dtype=torch.bool, device=y_prob.device) diff --git a/tests/core/test_deeplift.py b/tests/core/test_deeplift.py index 8fe832d82..29dbf8328 100644 --- a/tests/core/test_deeplift.py +++ b/tests/core/test_deeplift.py @@ -88,15 +88,15 @@ def test_basic_attribution(self): self.assertIsInstance(attributions["procedures"], torch.Tensor) def test_attribution_with_target_class(self): - """Test attribution computation with specific target class.""" + """For binary (single logit), target_class_idx is a no-op.""" dl = DeepLift(self.model) data_batch = next(iter(self.test_loader)) attr_class_0 = dl.attribute(**data_batch, target_class_idx=0) attr_class_1 = dl.attribute(**data_batch, target_class_idx=1) - # Attributions should differ for different classes - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( torch.allclose(attr_class_0["conditions"], attr_class_1["conditions"]) ) diff --git a/tests/core/test_gim.py b/tests/core/test_gim.py index 284931883..4e7cfb753 100644 --- a/tests/core/test_gim.py +++ b/tests/core/test_gim.py @@ -282,8 +282,8 @@ def _manual_token_attribution( output = model.forward_from_embedding(codes=tuple(parts), label=labels) logits = output["logit"] - # Binary mode: target class 0 → sign = -1 (2*0 - 1) - target = (-1.0 * logits).sum() + # Binary (single logit): _resolve_target_indices always selects index 0. 
+ target = logits.sum() model.zero_grad(set_to_none=True) if embeddings.grad is not None: diff --git a/tests/core/test_ig_gim.py b/tests/core/test_ig_gim.py index 38565227d..fed1b436f 100644 --- a/tests/core/test_ig_gim.py +++ b/tests/core/test_ig_gim.py @@ -567,7 +567,7 @@ def test_auto_target_class(self): self.assertEqual(attrs["codes"].shape, self.tokens.shape) def test_different_target_classes(self): - """Attributions for different target classes should differ.""" + """For binary (single logit), target_class_idx is a no-op.""" model = _ToyModel() ig_gim = IntegratedGradientGIM(model, temperature=1.0, steps=10) @@ -578,9 +578,10 @@ def test_different_target_classes(self): codes=self.tokens, label=self.labels, target_class_idx=1, )["codes"] - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( torch.allclose(attrs_0, attrs_1), - "Different target classes should give different attributions", + "Single-logit binary: target_class_idx is a no-op", ) # ----- Temporal tuple inputs ----- diff --git a/tests/core/test_integrated_gradients.py b/tests/core/test_integrated_gradients.py index 7c9212280..78ec4fab5 100644 --- a/tests/core/test_integrated_gradients.py +++ b/tests/core/test_integrated_gradients.py @@ -93,7 +93,7 @@ def test_basic_attribution(self): self.assertIsInstance(attributions["procedures"], torch.Tensor) def test_attribution_with_target_class(self): - """Test attribution computation with specific target class.""" + """For binary (single logit), target_class_idx is a no-op.""" ig = IntegratedGradients(self.model) data_batch = next(iter(self.test_loader)) @@ -103,8 +103,8 @@ def test_attribution_with_target_class(self): # Compute attributions for class 1 attr_class_1 = ig.attribute(**data_batch, target_class_idx=1, steps=10) - # Check that attributions are different for different classes - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( 
torch.allclose(attr_class_0["conditions"], attr_class_1["conditions"]) ) @@ -307,7 +307,7 @@ def test_attribution_shapes_stagenet(self): self.assertEqual(attributions[key].shape, value_tensor.shape) def test_attribution_with_target_class_stagenet(self): - """Test attribution with specific target class for StageNet.""" + """For binary (single logit), target_class_idx is a no-op.""" ig = IntegratedGradients(self.model) data_batch = next(iter(self.test_loader)) @@ -315,8 +315,8 @@ def test_attribution_with_target_class_stagenet(self): attr_0 = ig.attribute(**data_batch, target_class_idx=0, steps=10) attr_1 = ig.attribute(**data_batch, target_class_idx=1, steps=10) - # Check that attributions differ for different classes - self.assertFalse(torch.allclose(attr_0["codes"], attr_1["codes"])) + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue(torch.allclose(attr_0["codes"], attr_1["codes"])) def test_attribution_values_finite_stagenet(self): """Test that StageNet attributions are finite.""" diff --git a/tests/core/test_interp_metrics.py b/tests/core/test_interp_metrics.py index 9c46cf4ee..c415f1a9c 100644 --- a/tests/core/test_interp_metrics.py +++ b/tests/core/test_interp_metrics.py @@ -17,6 +17,7 @@ ComprehensivenessMetric, Evaluator, SufficiencyMetric, + threshold_sample_filter, ) from pyhealth.models import StageNet @@ -122,6 +123,7 @@ def setUp(self): # Initialize Integrated Gradients for attribution computation self.ig = IntegratedGradients(self.model, use_embeddings=True) + self.sample_filter = threshold_sample_filter() # Helper method to create attributions for a batch using IG def _create_attributions(self, batch, target_class_idx=1): @@ -227,7 +229,10 @@ def test_comprehensiveness_metric_basic(self): # Initialize metric comp = ComprehensivenessMetric( - self.model, percentages=[10, 20, 50], ablation_strategy="zero" + self.model, + percentages=[10, 20, 50], + ablation_strategy="zero", + sample_filter=self.sample_filter, ) # 
Compute scores - now returns (scores, valid_mask) tuple @@ -257,7 +262,10 @@ def test_sufficiency_metric_basic(self): # Initialize metric suff = SufficiencyMetric( - self.model, percentages=[10, 20, 50], ablation_strategy="zero" + self.model, + percentages=[10, 20, 50], + ablation_strategy="zero", + sample_filter=self.sample_filter, ) # Compute scores - now returns (scores, valid_mask) tuple @@ -282,7 +290,11 @@ def test_detailed_scores(self): """Test that detailed scores return per-percentage results.""" attributions = self._create_attributions(self.batch) - comp = ComprehensivenessMetric(self.model, percentages=[10, 20, 50]) + comp = ComprehensivenessMetric( + self.model, + percentages=[10, 20, 50], + sample_filter=self.sample_filter, + ) # Get detailed scores using return_per_percentage=True detailed = comp.compute(self.batch, attributions, return_per_percentage=True) @@ -305,7 +317,10 @@ def test_ablation_strategies(self): for strategy in strategies: comp = ComprehensivenessMetric( - self.model, percentages=[10, 20], ablation_strategy=strategy + self.model, + percentages=[10, 20], + ablation_strategy=strategy, + sample_filter=self.sample_filter, ) # Compute returns (scores, valid_mask) tuple scores, valid_mask = comp.compute(self.batch, attributions) @@ -414,7 +429,10 @@ def test_percentage_sensitivity(self): attributions = self._create_attributions(self.batch) comp = ComprehensivenessMetric( - self.model, percentages=[1, 10, 50], ablation_strategy="zero" + self.model, + percentages=[1, 10, 50], + ablation_strategy="zero", + sample_filter=self.sample_filter, ) detailed = comp.compute(self.batch, attributions, return_per_percentage=True) diff --git a/tests/core/test_lime.py b/tests/core/test_lime.py index ab061cb4a..55dcf0ff0 100644 --- a/tests/core/test_lime.py +++ b/tests/core/test_lime.py @@ -308,7 +308,7 @@ def test_target_class_idx_none(self): self.assertEqual(attributions["x"].shape, inputs.shape) def test_target_class_idx_specified(self): - """Should 
handle specific target class index.""" + """For binary (single logit), target_class_idx is a no-op.""" inputs = torch.tensor([[1.0, 0.5, -0.3]]) attr_class_0 = self.explainer.attribute( @@ -323,8 +323,8 @@ def test_target_class_idx_specified(self): target_class_idx=1, ) - # Attributions should differ for different classes - self.assertFalse(torch.allclose(attr_class_0["x"], attr_class_1["x"], atol=0.01)) + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue(torch.allclose(attr_class_0["x"], attr_class_1["x"], atol=0.01)) def test_attribution_values_are_finite(self): """Test that attribution values are finite (no NaN or Inf).""" @@ -777,7 +777,7 @@ def test_lime_mlp_basic_attribution(self): self.assertIsInstance(attributions["procedures"], torch.Tensor) def test_lime_mlp_with_target_class(self): - """Test LIME attribution with specific target class.""" + """For binary (single logit), target_class_idx is a no-op.""" explainer = LimeExplainer( self.model, use_embeddings=True, @@ -792,8 +792,8 @@ def test_lime_mlp_with_target_class(self): # Compute attributions for class 1 attr_class_1 = explainer.attribute(**data_batch, target_class_idx=1) - # Check that attributions are different for different classes - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( torch.allclose(attr_class_0["conditions"], attr_class_1["conditions"], atol=0.01) ) diff --git a/tests/core/test_shap.py b/tests/core/test_shap.py index 8c03a1c1f..25d703a0d 100644 --- a/tests/core/test_shap.py +++ b/tests/core/test_shap.py @@ -318,7 +318,7 @@ def test_target_class_idx_none(self): self.assertEqual(attributions["x"].shape, inputs.shape) def test_target_class_idx_specified(self): - """Should handle specific target class index.""" + """For binary (single logit), target_class_idx is a no-op.""" inputs = torch.tensor([[1.0, 0.5, -0.3]]) attr_class_0 = self.explainer.attribute( @@ -333,8 +333,8 @@ def 
test_target_class_idx_specified(self): target_class_idx=1, ) - # Attributions should differ for different classes - self.assertFalse(torch.allclose(attr_class_0["x"], attr_class_1["x"], atol=0.01)) + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue(torch.allclose(attr_class_0["x"], attr_class_1["x"], atol=0.01)) def test_attribution_values_are_finite(self): """Test that attribution values are finite (no NaN or Inf).""" @@ -696,7 +696,7 @@ def test_shap_mlp_basic_attribution(self): self.assertIsInstance(attributions["procedures"], torch.Tensor) def test_shap_mlp_with_target_class(self): - """Test SHAP attribution with specific target class.""" + """For binary (single logit), target_class_idx is a no-op.""" explainer = ShapExplainer(self.model) data_batch = next(iter(self.test_loader)) @@ -706,8 +706,8 @@ def test_shap_mlp_with_target_class(self): # Compute attributions for class 1 attr_class_1 = explainer.attribute(**data_batch, target_class_idx=1) - # Check that attributions are different for different classes - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( torch.allclose(attr_class_0["conditions"], attr_class_1["conditions"], atol=0.01) ) @@ -1023,23 +1023,13 @@ def test_kernel_weight_computation_edge_cases(self): self.assertTrue(torch.isfinite(weight_partial)) def test_target_prediction_extraction_binary(self): - """Test target prediction extraction for binary classification.""" - explainer = ShapExplainer( - self.model, - use_embeddings=False, - ) + """Test target prediction for binary classification via gather.""" # Single logit (binary classification) logits_binary = torch.tensor([[0.5], [1.0], [-0.3]]) - - # Class 1 target tensor - target_1 = torch.tensor([1, 1, 1]) - pred_1 = explainer._extract_target_prediction(logits_binary, target_1) - self.assertEqual(pred_1.shape, (3,)) - - # Class 0 target tensor - target_0 = torch.tensor([0, 0, 0]) - pred_0 = 
explainer._extract_target_prediction(logits_binary, target_0) - self.assertEqual(pred_0.shape, (3,)) + target_indices = torch.zeros(3, dtype=torch.long) + pred = logits_binary.gather(1, target_indices.unsqueeze(1)).squeeze(1) + self.assertEqual(pred.shape, (3,)) + torch.testing.assert_close(pred, torch.tensor([0.5, 1.0, -0.3])) def test_shape_mapping_simple(self): """Test mapping SHAP values back to input shapes.""" diff --git a/tests/core/test_transformer.py b/tests/core/test_transformer.py index a5fa6cc6b..e74468fc5 100644 --- a/tests/core/test_transformer.py +++ b/tests/core/test_transformer.py @@ -154,8 +154,7 @@ def test_chefer_relevance(self): relevance = CheferRelevance(model) # Test with explicitly specified class index - data_batch["class_index"] = 0 - scores = relevance.get_relevance_matrix(**data_batch) + scores = relevance.get_relevance_matrix(target_class_idx=0, **data_batch) # Verify that scores are returned for all feature keys self.assertIsInstance(scores, dict) @@ -167,9 +166,8 @@ def test_chefer_relevance(self): # Verify scores are non-negative (due to clamping in relevance computation) self.assertTrue(torch.all(scores[feature_key] >= 0)) - # Test without specifying class_index (should use predicted class) - data_batch_no_idx = {k: v for k, v in data_batch.items() if k != "class_index"} - scores_auto = relevance.get_relevance_matrix(**data_batch_no_idx) + # Test without specifying target_class_idx (should use predicted class) + scores_auto = relevance.get_relevance_matrix(**data_batch) # Verify that scores are returned self.assertIsInstance(scores_auto, dict)