Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 75 additions & 3 deletions howso/client/schemas/aggregate_reaction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Mapping, Iterator
from copy import deepcopy
from pprint import pformat
from typing import Any, Literal, overload, TypeAlias, TypeVar
from typing import Any, Literal, overload, TypeAlias, TypeVar, get_args

import pandas as pd

Expand Down Expand Up @@ -55,6 +55,23 @@
]
"""Metric output keys of react aggregate that can be combined together into a DataFrame."""

ResidualTypes: TypeAlias = Literal[
"feature_full_residuals",
"feature_robust_residuals",
"feature_deviations",
]

FeatureContributionTypes: TypeAlias = Literal[
"feature_full_prediction_contributions",
"feature_full_directional_prediction_contributions",
"feature_robust_prediction_contributions",
"feature_robust_directional_prediction_contributions",
"feature_full_accuracy_contributions",
"feature_full_accuracy_contributions_permutation",
"feature_robust_accuracy_contributions",
"feature_robust_accuracy_contributions_permutation",
]

Metric: TypeAlias = Literal[ComplexMetric, TableMetric]
"""All metric output keys of react aggregate."""

Expand Down Expand Up @@ -84,7 +101,7 @@ def __getitem__(self, key: Literal["confusion_matrix"]) -> dict[str, ConfusionMa
@overload
def __getitem__(self, key: TableMetric) -> pd.DataFrame: ...

def __getitem__(self, key: Metric) -> Any:
def __getitem__(self, key: Metric) -> Any | None:
value = self._data[key]
if isinstance(value, Mapping):
if key == "confusion_matrix":
Expand Down Expand Up @@ -124,10 +141,65 @@ def get(self, key: TableMetric, /) -> pd.DataFrame | None: ...
def get(self, key: TableMetric, /, default: _VT) -> pd.DataFrame | _VT: ...

def get( # pyright: ignore[reportIncompatibleMethodOverride]
self, key: Metric, /, default: _VT | None = None
self, key: Metric, /, default: _VT | None = None,
) -> MetricValue | _VT | None:
return super().get(key, default=default)

def get_feature_residuals(
self,
null_residuals: bool = False,
) -> pd.DataFrame:
"""
Get the computed feature residuals as a DataFrame.

Parameters
----------
null_residuals : bool, default False
A flag indicating if the residuals for the nullness of features should be returned rather than
the residual for the values of non-null cases.

Returns
-------
DataFrame
The DataFrame representation of the computed feature residuals.
"""
data = {x: self._data[x] for x in get_args(ResidualTypes) if x in self._data}
if null_residuals:
def map_func(x):
return x[1] if isinstance(x, list) else None
else:
def map_func(x):
return x[0] if isinstance(x, list) else x

return pd.DataFrame(data).map(map_func)

def get_feature_contributions(
self, key: FeatureContributionTypes,
null_contributions=False,
) -> pd.DataFrame:
"""
Get the computed feature contributions as a DataFrame.

Parameters
----------
null_contributions : bool, default False
A flag indicating if the contributions for the nullness of features should be returned rather than
the residual for the values of non-null cases.

Returns
-------
DataFrame
The DataFrame representation of the computed feature contributions.
"""
value = self._data[key]
if null_contributions:
def map_func(x):
return x[1] if isinstance(x, list) else None
else:
def map_func(x):
return x[0] if isinstance(x, list) else x
return pd.DataFrame(value).map(map_func)

def __iter__(self) -> Iterator[Metric]:
"""Iterate over the keys."""
return iter(self._data)
Expand Down
Loading