Description
The `FBetaScore` metric in `metrax.classification_metrics` does not implement the `merge` method. This causes a `NotImplementedError` when the metric is used through the `metrax.nnx` wrapper (e.g., `metrax.nnx.FBetaScore`). The wrapper's `update` method calls `merge` to combine the metric for the incoming batch with the accumulated state.
The `merge` method appears to be commented out in the source, in `metrax/src/metrax/classification_metrics.py` (lines 677 to 695 in 4ff6ccf):
```python
  """
  This function is currently unused as the 'from_model_output' function can handle the whole
  dataset without needing to split and merge them. I'm leaving this here for now incase we want to
  repurpose this or need to change something that requires this function's use again. This function would need
  to be reworked for it to work with the current implementation of this class.
  """
  # # Merge datasets together
  # def merge(self, other: 'FBetaScore') -> 'FBetaScore':
  #
  #   # Check if the incoming beta is the same value as the current beta
  #   if other.beta == self.beta:
  #     return type(self)(
  #       true_positives = self.true_positives + other.true_positives,
  #       false_positives = self.false_positives + other.false_positives,
  #       false_negatives = self.false_negatives + other.false_negatives,
  #       beta=self.beta,
  #     )
  #   else:
  #     raise ValueError('The "Beta" values between the two are not equal.')
```
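True positive, false positive and false negative counts are additive across batches, so a working `merge` can be about as simple as the commented-out version. The following is a minimal, framework-free sketch using the standard F-beta formula. It assumes a fixed 0.5 threshold. `FBetaCounts` and `counts_from_batch` are hypothetical names, not the metrax API, and the real class may still need the rework mentioned in the docstring. The sketch checks that merging two batches yields the same counts, and therefore the same score, as processing the whole dataset at once.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FBetaCounts:
    """Hypothetical container for the additive statistics behind F-beta.

    Field names follow the commented-out merge(); this is not the metrax class.
    """

    true_positives: float
    false_positives: float
    false_negatives: float
    beta: float = 1.0

    def merge(self, other: "FBetaCounts") -> "FBetaCounts":
        # Counts are only combinable when both sides use the same beta.
        if other.beta != self.beta:
            raise ValueError('The "Beta" values between the two are not equal.')
        return FBetaCounts(
            true_positives=self.true_positives + other.true_positives,
            false_positives=self.false_positives + other.false_positives,
            false_negatives=self.false_negatives + other.false_negatives,
            beta=self.beta,
        )

    def compute(self) -> float:
        # F_beta = (1 + b^2) TP / ((1 + b^2) TP + b^2 FN + FP)
        b2 = self.beta ** 2
        denom = (1 + b2) * self.true_positives + b2 * self.false_negatives + self.false_positives
        return 0.0 if denom == 0 else (1 + b2) * self.true_positives / denom


def counts_from_batch(predictions, labels, beta=1.0, threshold=0.5):
    """Hypothetical helper: threshold scores (0.5 is an assumption) and count TP/FP/FN."""
    tp = fp = fn = 0.0
    for p, y in zip(predictions, labels):
        pred = p >= threshold
        if pred and y == 1:
            tp += 1
        elif pred and y == 0:
            fp += 1
        elif not pred and y == 1:
            fn += 1
    return FBetaCounts(tp, fp, fn, beta)


preds = [0.9, 0.2, 0.7, 0.4, 0.8, 0.1]
labels = [1, 0, 0, 1, 1, 1]

# Whole dataset at once vs. two batches merged together.
whole = counts_from_batch(preds, labels)
merged = counts_from_batch(preds[:3], labels[:3]).merge(
    counts_from_batch(preds[3:], labels[3:])
)
print(whole == merged, whole.compute())  # True 0.5714285714285714
```

Because `merge` only sums counts, the result does not depend on how the data was split into batches. This is exactly the property the `nnx` wrapper relies on when it merges once per `update` call.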
Minimal Reproduction
```python
import metrax.nnx
import jax.numpy as jnp
import jax.random

# Set up dummy data
predictions = jax.random.normal(jax.random.PRNGKey(0), (3,))
labels = jnp.arange(3) % 2
# Initialize and update metric
f1_metric = metrax.nnx.FBetaScore()
f1_metric.update(predictions=predictions, labels=labels)  # Raises NotImplementedError
```
Traceback
```
Traceback (most recent call last):
  File "repro.py", line 10, in <module>
    f1_metric.update(predictions=predictions, labels=labels)
  File ".../site-packages/metrax/nnx/nnx_wrapper.py", line 31, in update
    self.clu_metric = self.clu_metric.merge(other_clu_metric)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/clu/metrics.py", line 148, in merge
    raise NotImplementedError("Must override merge()")
NotImplementedError: Must override merge()
```
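The call chain in the traceback can be mimicked in plain Python to show why `update` fails. This is a simplified, hypothetical sketch. `BaseMetric`, `CountsWithoutMerge` and `HypotheticalWrapper` are illustrative names, not metrax or clu classes. The sketch also assumes that the wrapper builds the per-batch metric with `from_model_output` before merging. Only the `merge` call itself is confirmed by the traceback.

```python
from dataclasses import dataclass


class BaseMetric:
    """Hypothetical stand-in for a clu-style base class: merge() must be overridden."""

    def merge(self, other: "BaseMetric") -> "BaseMetric":
        raise NotImplementedError("Must override merge()")


@dataclass(frozen=True)
class CountsWithoutMerge(BaseMetric):
    """Hypothetical metric that stores a count but never overrides merge()."""

    true_positives: float = 0.0

    @classmethod
    def from_model_output(cls, predictions, labels, threshold=0.5):
        # Count positions where the thresholded prediction and the label are both positive.
        tp = sum(1.0 for p, y in zip(predictions, labels) if p >= threshold and y == 1)
        return cls(true_positives=tp)


class HypotheticalWrapper:
    """Simplified stateful wrapper: each update() builds a metric for the batch
    (assumed to use from_model_output) and merges it into the accumulated one."""

    def __init__(self, metric_cls):
        self.metric_cls = metric_cls
        self.clu_metric = metric_cls()

    def update(self, **kwargs):
        other_clu_metric = self.metric_cls.from_model_output(**kwargs)
        # Same shape as nnx_wrapper.py line 31 in the traceback; fails without merge().
        self.clu_metric = self.clu_metric.merge(other_clu_metric)


wrapper = HypotheticalWrapper(CountsWithoutMerge)
try:
    wrapper.update(predictions=[0.9, 0.2, 0.7], labels=[1, 0, 0])
    error = None
except NotImplementedError as e:
    error = str(e)
print(error)  # Must override merge()
```

Overriding `merge` in the subclass is enough to make this `update` loop work. That is the gap the commented-out `FBetaScore.merge` would need to fill.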