Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions pyhealth/interpret/methods/integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class IntegratedGradients(BaseInterpreter):
... )
"""

def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50):
def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50, decision_threshold: float = 0.5):
"""Initialize IntegratedGradients interpreter.

Args:
Expand All @@ -181,6 +181,8 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 5
approximation of the path integral. Default is 50.
Can be overridden in attribute() calls. More steps lead to
better approximation but slower computation.
decision_threshold: Probability threshold compared against sigmoid-activated
logits when inferring the default target for binary and multilabel
prediction. Default is 0.5.

Raises:
AssertionError: If use_embeddings=True but model does not
Expand All @@ -193,6 +195,7 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 5

self.use_embeddings = use_embeddings
self.steps = steps
self.decision_threshold = decision_threshold


def attribute(
Expand All @@ -217,9 +220,17 @@ def attribute(
the integral. If None, uses self.steps (set during
initialization). More steps lead to better approximation but
slower computation.
target_class_idx: Target class index for attribution
computation. If None, uses the predicted class (argmax of
model output).
target_class_idx: Target used for attribution computation.
Default behavior depends on prediction mode:
- binary: uses (sigmoid(logit) > decision_threshold)
- multiclass: uses argmax(logits)
- multilabel: uses (sigmoid(logits) > decision_threshold)
Notes:
- In binary mode, target_class_idx effectively behaves like a target
label (0 or 1), not a class-axis index.
- In multilabel mode, if target_class_idx is None, the default target is
a multi-hot mask of all predicted-positive labels, and attribution is
computed for the sum of those selected logits.
**kwargs: Input data dictionary from a dataloader batch
containing:
- Feature keys (e.g., 'conditions', 'procedures'):
Expand Down Expand Up @@ -329,7 +340,7 @@ def attribute(
if target_class_idx is not None:
target = torch.tensor([target_class_idx], device=device)
else:
target = (torch.sigmoid(base_logits) > 0.5).long()
target = (torch.sigmoid(base_logits) > self.decision_threshold).long()
elif mode == "multiclass":
if target_class_idx is not None:
target = F.one_hot(
Expand All @@ -348,7 +359,7 @@ def attribute(
num_classes=base_logits.shape[-1],
).float()
else:
target = (torch.sigmoid(base_logits) > 0.5).float()
target = (torch.sigmoid(base_logits) > self.decision_threshold).float()
else:
raise ValueError(
"Unsupported prediction mode for Integrated Gradients attribution."
Expand Down Expand Up @@ -555,16 +566,19 @@ def _compute_target_output(
"""Compute scalar target output for backpropagation.

Creates a differentiable scalar from the model logits that,
when differentiated, gives the gradient of the target class
logit w.r.t. the input.
when differentiated, gives the gradient of the selected target logit(s)
w.r.t. the input.

Args:
logits: Model output logits, shape [batch, num_classes] or
[batch, 1].
target: Target tensor. For binary: [batch] or [1] with 0/1
class indices. For multiclass/multilabel: [batch, num_classes]
one-hot or multi-hot tensor.
target: Target tensor. For binary: [batch] or [1] with 0/1 target labels.
For multiclass: [batch, num_classes] one-hot tensor.
For multilabel: [batch, num_classes] one-hot or multi-hot tensor.

In multilabel mode, a multi-hot target corresponds to the sum of
the selected logits.

Returns:
Scalar tensor for backpropagation.
"""
Expand Down