From 18de846b7b5c3fec37398d6a93f4e989b6fd5a8e Mon Sep 17 00:00:00 2001 From: ParthAgarwalCode Date: Sat, 4 Apr 2026 23:51:37 -0400 Subject: [PATCH] Clarify IntegratedGradients target semantics for binary and multilabel settings This PR clarifies IntegratedGradients target semantics across prediction modes. Changes: - add `decision_threshold` for binary/multilabel default target inference - use `decision_threshold` instead of hardcoded `0.5` - document binary default behavior - document multilabel default behavior - clarify that `target_class_idx` in binary behaves like target label selection - clarify that multilabel default attribution uses the sum of selected logits This does not change the IG algorithm itself, and preserves existing default behavior when `decision_threshold=0.5`. --- .../interpret/methods/integrated_gradients.py | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/pyhealth/interpret/methods/integrated_gradients.py b/pyhealth/interpret/methods/integrated_gradients.py index a529a6f3f..52f379192 100644 --- a/pyhealth/interpret/methods/integrated_gradients.py +++ b/pyhealth/interpret/methods/integrated_gradients.py @@ -166,7 +166,7 @@ class IntegratedGradients(BaseInterpreter): ... ) """ - def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50): + def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50, decision_threshold: float = 0.5): """Initialize IntegratedGradients interpreter. Args: @@ -181,6 +181,8 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 5 approximation of the path integral. Default is 50. Can be overridden in attribute() calls. More steps lead to better approximation but slower computation. + decision_threshold: Decision threshold used when inferring the default + target for binary and multilabel prediction. Default is 0.5. Raises: AssertionError: If use_embeddings=True but model does not @@ -193,6 +195,7 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 5 self.use_embeddings = use_embeddings self.steps = steps + self.decision_threshold = decision_threshold def attribute( @@ -217,9 +220,17 @@ def attribute( the integral. If None, uses self.steps (set during initialization). More steps lead to better approximation but slower computation. - target_class_idx: Target class index for attribution - computation. If None, uses the predicted class (argmax of - model output). + target_class_idx: Target used for attribution computation. + Default behavior depends on prediction mode: + - binary: uses (sigmoid(logit) > decision_threshold) + - multiclass: uses argmax(logits) + - multilabel: uses (sigmoid(logits) > decision_threshold) + Notes: + - In binary mode, target_class_idx effectively behaves like a target + label (0 or 1), not a class-axis index. + - In multilabel mode, if target_class_idx is None, the default target is + a multi-hot mask of all predicted-positive labels, and attribution is + computed for the sum of those selected logits. **kwargs: Input data dictionary from a dataloader batch containing: - Feature keys (e.g., 'conditions', 'procedures'): @@ -329,7 +340,7 @@ def attribute( if target_class_idx is not None: target = torch.tensor([target_class_idx], device=device) else: - target = (torch.sigmoid(base_logits) > 0.5).long() + target = (torch.sigmoid(base_logits) > self.decision_threshold).long() elif mode == "multiclass": if target_class_idx is not None: target = F.one_hot( @@ -348,7 +359,7 @@ def attribute( num_classes=base_logits.shape[-1], ).float() else: - target = (torch.sigmoid(base_logits) > 0.5).float() + target = (torch.sigmoid(base_logits) > self.decision_threshold).float() else: raise ValueError( "Unsupported prediction mode for Integrated Gradients attribution." @@ -555,16 +566,19 @@ def _compute_target_output( """Compute scalar target output for backpropagation. Creates a differentiable scalar from the model logits that, - when differentiated, gives the gradient of the target class - logit w.r.t. the input. + when differentiated, gives the gradient of the selected target logit(s) + w.r.t. the input. Args: logits: Model output logits, shape [batch, num_classes] or [batch, 1]. - target: Target tensor. For binary: [batch] or [1] with 0/1 - class indices. For multiclass/multilabel: [batch, num_classes] - one-hot or multi-hot tensor. + target: Target tensor. For binary: [batch] or [1] with 0/1 target labels. + For multiclass: [batch, num_classes] one-hot tensor. + For multilabel: [batch, num_classes] one-hot or multi-hot tensor. + In multilabel mode, a multi-hot target corresponds to the sum of + the selected logits. + Returns: Scalar tensor for backpropagation. """