From 2bde2a32cf9fcac5942627e578531940edc611e9 Mon Sep 17 00:00:00 2001
From: Cade Mack <24661281+cademack@users.noreply.github.com>
Date: Thu, 26 Mar 2026 11:03:09 -0400
Subject: [PATCH] Add helper methods to get cleaned contributions/residuals

---
 howso/client/schemas/aggregate_reaction.py | 86 +++++++++++++++++++++-
 1 file changed, 83 insertions(+), 3 deletions(-)

diff --git a/howso/client/schemas/aggregate_reaction.py b/howso/client/schemas/aggregate_reaction.py
index 8edbd239..f34f1b3a 100644
--- a/howso/client/schemas/aggregate_reaction.py
+++ b/howso/client/schemas/aggregate_reaction.py
@@ -1,7 +1,7 @@
 from collections.abc import Mapping, Iterator
 from copy import deepcopy
 from pprint import pformat
-from typing import Any, Literal, overload, TypeAlias, TypeVar
+from typing import Any, Literal, overload, TypeAlias, TypeVar, get_args
 
 import pandas as pd
 
@@ -55,6 +55,25 @@
 ]
 """Metric output keys of react aggregate that can be combined together into a DataFrame."""
 
+ResidualTypes: TypeAlias = Literal[
+    "feature_full_residuals",
+    "feature_robust_residuals",
+    "feature_deviations",
+]
+"""Metric output keys of react aggregate that are returned by `get_feature_residuals`."""
+
+FeatureContributionTypes: TypeAlias = Literal[
+    "feature_full_prediction_contributions",
+    "feature_full_directional_prediction_contributions",
+    "feature_robust_prediction_contributions",
+    "feature_robust_directional_prediction_contributions",
+    "feature_full_accuracy_contributions",
+    "feature_full_accuracy_contributions_permutation",
+    "feature_robust_accuracy_contributions",
+    "feature_robust_accuracy_contributions_permutation",
+]
+"""Metric output keys of react aggregate that are accepted by `get_feature_contributions`."""
+
 Metric: TypeAlias = Literal[ComplexMetric, TableMetric]
 """All metric output keys of react aggregate."""
 
@@ -84,7 +103,7 @@ def __getitem__(self, key: Literal["confusion_matrix"]) -> dict[str, ConfusionMa
     @overload
     def __getitem__(self, key: TableMetric) -> pd.DataFrame: ...
 
-    def __getitem__(self, key: Metric) -> Any:
+    def __getitem__(self, key: Metric) -> Any | None:
         value = self._data[key]
         if isinstance(value, Mapping):
             if key == "confusion_matrix":
@@ -124,10 +143,71 @@ def get(self, key: TableMetric, /) -> pd.DataFrame | None: ...
     def get(self, key: TableMetric, /, default: _VT) -> pd.DataFrame | _VT: ...
 
     def get(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, key: Metric, /, default: _VT | None = None
+        self, key: Metric, /, default: _VT | None = None,
     ) -> MetricValue | _VT | None:
         return super().get(key, default=default)
 
+    def get_feature_residuals(
+        self,
+        null_residuals: bool = False,
+    ) -> pd.DataFrame:
+        """
+        Get the computed feature residuals as a DataFrame.
+
+        Only the residual metrics present in the computed results are included.
+
+        Parameters
+        ----------
+        null_residuals : bool, default False
+            A flag indicating if the residuals for the nullness of features should be returned rather than
+            the residual for the values of non-null cases.
+
+        Returns
+        -------
+        DataFrame
+            The DataFrame representation of the computed feature residuals.
+        """
+        data = {x: self._data[x] for x in get_args(ResidualTypes) if x in self._data}
+        if null_residuals:
+            def map_func(x):
+                return x[1] if isinstance(x, list) else None
+        else:
+            def map_func(x):
+                return x[0] if isinstance(x, list) else x
+
+        return pd.DataFrame(data).map(map_func)
+
+    def get_feature_contributions(
+        self,
+        key: FeatureContributionTypes,
+        null_contributions: bool = False,
+    ) -> pd.DataFrame:
+        """
+        Get the computed feature contributions as a DataFrame.
+
+        Parameters
+        ----------
+        key : FeatureContributionTypes
+            The feature contribution metric to retrieve.
+        null_contributions : bool, default False
+            A flag indicating if the contributions for the nullness of features should be returned rather than
+            the contributions for the values of non-null cases.
+
+        Returns
+        -------
+        DataFrame
+            The DataFrame representation of the computed feature contributions.
+        """
+        value = self._data[key]
+        if null_contributions:
+            def map_func(x):
+                return x[1] if isinstance(x, list) else None
+        else:
+            def map_func(x):
+                return x[0] if isinstance(x, list) else x
+
+        return pd.DataFrame(value).map(map_func)
+
     def __iter__(self) -> Iterator[Metric]:
         """Iterate over the keys."""
         return iter(self._data)