Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/metrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@
MAE = regression_metrics.MAE
MRR = ranking_metrics.MRR
MSE = regression_metrics.MSE
MSLE = regression_metrics.MSLE
NDCGAtK = ranking_metrics.NDCGAtK
Perplexity = nlp_metrics.Perplexity
Precision = classification_metrics.Precision
PrecisionAtK = ranking_metrics.PrecisionAtK
PSNR = image_metrics.PSNR
RMSE = regression_metrics.RMSE
RMSLE = regression_metrics.RMSLE
RSQUARED = regression_metrics.RSQUARED
Recall = classification_metrics.Recall
RecallAtK = ranking_metrics.RecallAtK
Expand All @@ -66,12 +68,14 @@
"MAE",
"MRR",
"MSE",
"MSLE",
"NDCGAtK",
"Perplexity",
"Precision",
"PrecisionAtK",
"PSNR",
"RMSE",
"RMSLE",
"RSQUARED",
"SpearmanRankCorrelation",
"Recall",
Expand Down
11 changes: 11 additions & 0 deletions src/metrax/metrax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ class MetraxTest(parameterized.TestCase):
metrax.MSE,
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
),
(
'msle',
metrax.MSLE,
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
),

(
'ndcgAtK',
metrax.NDCGAtK,
Expand Down Expand Up @@ -190,6 +196,11 @@ class MetraxTest(parameterized.TestCase):
metrax.RMSE,
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
),
(
'rmsle',
metrax.RMSLE,
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
),
(
'rsquared',
metrax.RSQUARED,
Expand Down
2 changes: 2 additions & 0 deletions src/metrax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
MAE = nnx_metrics.MAE
MRR = nnx_metrics.MRR
MSE = nnx_metrics.MSE
MSLE = nnx_metrics.MSLE
NDCGAtK = nnx_metrics.NDCGAtK
Perplexity = nnx_metrics.Perplexity
Precision = nnx_metrics.Precision
PrecisionAtK = nnx_metrics.PrecisionAtK
PSNR = nnx_metrics.PSNR
RMSE = nnx_metrics.RMSE
RMSLE = nnx_metrics.RMSLE
RSQUARED = nnx_metrics.RSQUARED
Recall = nnx_metrics.Recall
RecallAtK = nnx_metrics.RecallAtK
Expand Down
14 changes: 14 additions & 0 deletions src/metrax/nnx/nnx_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def __init__(self):
super().__init__(metrax.MSE)


class MSLE(NnxWrapper):
  """An NNX class for the Metrax metric MSLE.

  Thin adapter: delegates all state and update logic to the wrapped
  `metrax.MSLE` (mean squared logarithmic error) metric via `NnxWrapper`.
  """

  def __init__(self) -> None:
    # Bind the clu-style metrax.MSLE metric class into the NNX metric API.
    super().__init__(metrax.MSLE)


class NDCGAtK(NnxWrapper):
"""An NNX class for the Metrax metric NDCGAtK."""

Expand Down Expand Up @@ -170,6 +177,13 @@ def __init__(self):
super().__init__(metrax.RMSE)


class RMSLE(NnxWrapper):
  """An NNX class for the Metrax metric RMSLE.

  Thin adapter: delegates all state and update logic to the wrapped
  `metrax.RMSLE` (root mean squared logarithmic error) metric via
  `NnxWrapper`.
  """

  def __init__(self) -> None:
    # Bind the clu-style metrax.RMSLE metric class into the NNX metric API.
    super().__init__(metrax.RMSLE)


class RougeL(NnxWrapper):
"""An NNX class for the Metrax metric RougeL."""

Expand Down
68 changes: 68 additions & 0 deletions src/metrax/regression_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,74 @@ def compute(self) -> jax.Array:
return jnp.sqrt(super().compute())


@flax.struct.dataclass
class MSLE(base.Average):
  r"""Mean squared logarithmic error for regression `predictions` vs `labels`.

  The mean squared logarithmic error is defined as:

  .. math::
      MSLE = \frac{1}{N} \sum_{i=1}^{N} (ln(y_i + 1) - ln(\hat{y}_i + 1))^2

  where:
    - :math:`y_i` are true values
    - :math:`\hat{y}_i` are predictions
    - :math:`N` is the number of samples
  """

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
  ) -> 'MSLE':
    """Updates the metric.

    Args:
      predictions: A floating point 1D vector representing the prediction
        generated from the model. The shape should be (batch_size,).
      labels: True value. The shape should be (batch_size,).
      sample_weights: An optional floating point 1D vector representing the
        weight of each sample. The shape should be (batch_size,).

    Returns:
      Updated MSLE metric. The shape should be a single scalar.
    """
    # Squared difference in log space; log1p keeps zero-valued inputs finite.
    per_example = jnp.square(jnp.log1p(predictions) - jnp.log1p(labels))
    # One "count" unit per element; scaled below when weights are supplied so
    # the running mean becomes a weighted mean.
    weights = jnp.ones_like(labels, dtype=jnp.int32)
    if sample_weights is not None:
      per_example = per_example * sample_weights
      weights = weights * sample_weights
    return cls(total=per_example.sum(), count=weights.sum())


@flax.struct.dataclass
class RMSLE(MSLE):
  r"""Root mean squared logarithmic error for regression `predictions` vs `labels`.

  The root mean squared logarithmic error is defined as:

  .. math::
      RMSLE = \sqrt{\frac{1}{N} \sum_{i=1}^{N}
          (ln(y_i + 1) - ln(\hat{y}_i + 1))^2
      }

  where:
    - :math:`y_i` are true values
    - :math:`\hat{y}_i` are predictions
    - :math:`N` is the number of samples
  """

  def compute(self) -> jax.Array:
    """Returns the square root of the accumulated mean squared log error."""
    mean_squared_log_error = super().compute()
    return jnp.sqrt(mean_squared_log_error)


@flax.struct.dataclass
class RSQUARED(clu_metrics.Metric):
r"""Computes the r-squared score of a scalar or a batch of tensors.
Expand Down
80 changes: 80 additions & 0 deletions src/metrax/regression_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,86 @@ def test_rmse(self, y_true, y_pred, sample_weights):
atol=atol,
)

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, None),
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
)
def test_msle(self, y_true, y_pred, sample_weights):
"""Test that `MSLE` Metric computes correct values."""
y_true = y_true.astype(y_pred.dtype)
y_pred = y_pred.astype(y_true.dtype)
if sample_weights is None:
sample_weights = np.ones_like(y_true)

metric = None
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
update = metrax.MSLE.from_model_output(
predictions=logits,
labels=labels,
sample_weights=weights,
)
metric = update if metric is None else metric.merge(update)

expected = sklearn_metrics.mean_squared_log_error(
y_true.astype('float32').flatten(),
y_pred.astype('float32').flatten(),
sample_weight=sample_weights.astype('float32').flatten(),
)
# Use lower tolerance for lower precision dtypes.
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
np.testing.assert_allclose(
metric.compute(),
expected,
rtol=rtol,
atol=atol,
)

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, None),
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
)
def test_rmsle(self, y_true, y_pred, sample_weights):
"""Test that `RMSLE` Metric computes correct values."""
y_true = y_true.astype(y_pred.dtype)
y_pred = y_pred.astype(y_true.dtype)
if sample_weights is None:
sample_weights = np.ones_like(y_true)

metric = None
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
update = metrax.RMSLE.from_model_output(
predictions=logits,
labels=labels,
sample_weights=weights,
)
metric = update if metric is None else metric.merge(update)

expected = sklearn_metrics.root_mean_squared_log_error(
y_true.astype('float32').flatten(),
y_pred.astype('float32').flatten(),
sample_weight=sample_weights.astype('float32').flatten(),
)
# Use lower tolerance for lower precision dtypes.
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
np.testing.assert_allclose(
metric.compute(),
expected,
rtol=rtol,
atol=atol,
)

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),
Expand Down