diff --git a/src/metrax/__init__.py b/src/metrax/__init__.py index eb85737..5b7b151 100644 --- a/src/metrax/__init__.py +++ b/src/metrax/__init__.py @@ -34,12 +34,14 @@ MAE = regression_metrics.MAE MRR = ranking_metrics.MRR MSE = regression_metrics.MSE +MSLE = regression_metrics.MSLE NDCGAtK = ranking_metrics.NDCGAtK Perplexity = nlp_metrics.Perplexity Precision = classification_metrics.Precision PrecisionAtK = ranking_metrics.PrecisionAtK PSNR = image_metrics.PSNR RMSE = regression_metrics.RMSE +RMSLE = regression_metrics.RMSLE RSQUARED = regression_metrics.RSQUARED Recall = classification_metrics.Recall RecallAtK = ranking_metrics.RecallAtK @@ -66,12 +68,14 @@ "MAE", "MRR", "MSE", + "MSLE", "NDCGAtK", "Perplexity", "Precision", "PrecisionAtK", "PSNR", "RMSE", + "RMSLE", "RSQUARED", "SpearmanRankCorrelation", "Recall", diff --git a/src/metrax/metrax_test.py b/src/metrax/metrax_test.py index b6a5ce1..ad1dcf1 100644 --- a/src/metrax/metrax_test.py +++ b/src/metrax/metrax_test.py @@ -148,6 +148,12 @@ class MetraxTest(parameterized.TestCase): metrax.MSE, {'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS}, ), + ( + 'msle', + metrax.MSLE, + {'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS}, + ), + ( 'ndcgAtK', metrax.NDCGAtK, @@ -190,6 +196,11 @@ class MetraxTest(parameterized.TestCase): metrax.RMSE, {'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS}, ), + ( + 'rmsle', + metrax.RMSLE, + {'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS}, + ), ( 'rsquared', metrax.RSQUARED, diff --git a/src/metrax/nnx/__init__.py b/src/metrax/nnx/__init__.py index 07b40f2..df30ccf 100644 --- a/src/metrax/nnx/__init__.py +++ b/src/metrax/nnx/__init__.py @@ -28,12 +28,14 @@ MAE = nnx_metrics.MAE MRR = nnx_metrics.MRR MSE = nnx_metrics.MSE +MSLE = nnx_metrics.MSLE NDCGAtK = nnx_metrics.NDCGAtK Perplexity = nnx_metrics.Perplexity Precision = nnx_metrics.Precision PrecisionAtK = nnx_metrics.PrecisionAtK PSNR = nnx_metrics.PSNR RMSE = nnx_metrics.RMSE +RMSLE = nnx_metrics.RMSLE RSQUARED = nnx_metrics.RSQUARED Recall = nnx_metrics.Recall RecallAtK = nnx_metrics.RecallAtK diff --git a/src/metrax/nnx/nnx_metrics.py b/src/metrax/nnx/nnx_metrics.py index afdb97f..5177d46 100644 --- a/src/metrax/nnx/nnx_metrics.py +++ b/src/metrax/nnx/nnx_metrics.py @@ -114,6 +114,13 @@ def __init__(self): super().__init__(metrax.MSE) +class MSLE(NnxWrapper): + """An NNX class for the Metrax metric MSLE.""" + + def __init__(self): + super().__init__(metrax.MSLE) + + class NDCGAtK(NnxWrapper): """An NNX class for the Metrax metric NDCGAtK.""" @@ -170,6 +177,13 @@ def __init__(self): super().__init__(metrax.RMSE) +class RMSLE(NnxWrapper): + """An NNX class for the Metrax metric RMSLE.""" + + def __init__(self): + super().__init__(metrax.RMSLE) + + class RougeL(NnxWrapper): """An NNX class for the Metrax metric RougeL.""" diff --git a/src/metrax/regression_metrics.py b/src/metrax/regression_metrics.py index 8bb839d..a75e921 100644 --- a/src/metrax/regression_metrics.py +++ b/src/metrax/regression_metrics.py @@ -160,6 +160,74 @@ def compute(self) -> jax.Array: return jnp.sqrt(super().compute()) +@flax.struct.dataclass +class MSLE(base.Average): + r"""Computes the mean squared logarithmic error for regression problems given `predictions` and `labels`. + + The mean squared logarithmic error is defined as: + + .. math:: + MSLE = \frac{1}{N} \sum_{i=1}^{N} (ln(y_i + 1) - ln(\hat{y}_i + 1))^2 + + where: + - :math:`y_i` are true values + - :math:`\hat{y}_i` are predictions + - :math:`N` is the number of samples + """ + + @classmethod + def from_model_output( + cls, + predictions: jax.Array, + labels: jax.Array, + sample_weights: jax.Array | None = None, + ) -> 'MSLE': + """Updates the metric. + + Args: + predictions: A floating point 1D vector representing the prediction + generated from the model. The shape should be (batch_size,). + labels: True value. The shape should be (batch_size,). + sample_weights: An optional floating point 1D vector representing the + weight of each sample. The shape should be (batch_size,). + + Returns: + Updated MSLE metric. The shape should be a single scalar. + """ + log_predictions = jnp.log1p(predictions) + log_labels = jnp.log1p(labels) + squared_error = jnp.square(log_predictions - log_labels) + count = jnp.ones_like(labels, dtype=jnp.int32) + if sample_weights is not None: + squared_error = squared_error * sample_weights + count = count * sample_weights + return cls( + total=squared_error.sum(), + count=count.sum(), + ) + + +@flax.struct.dataclass +class RMSLE(MSLE): + r"""Computes the root mean squared logarithmic error for regression problems given `predictions` and `labels`. + + The root mean squared logarithmic error is defined as: + + .. math:: + RMSLE = \sqrt{\frac{1}{N} \sum_{i=1}^{N} + (ln(y_i + 1) - ln(\hat{y}_i + 1))^2 + } + + where: + - :math:`y_i` are true values + - :math:`\hat{y}_i` are predictions + - :math:`N` is the number of samples + """ + + def compute(self) -> jax.Array: + return jnp.sqrt(super().compute()) + + @flax.struct.dataclass class RSQUARED(clu_metrics.Metric): r"""Computes the r-squared score of a scalar or a batch of tensors. diff --git a/src/metrax/regression_metrics_test.py b/src/metrax/regression_metrics_test.py index 0b91277..cb75fda 100644 --- a/src/metrax/regression_metrics_test.py +++ b/src/metrax/regression_metrics_test.py @@ -282,6 +282,86 @@ def test_rmse(self, y_true, y_pred, sample_weights): atol=atol, ) + @parameterized.named_parameters( + ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None), + ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None), + ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, None), + ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None), + ('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS), + ('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS), + ('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS), + ) + def test_msle(self, y_true, y_pred, sample_weights): + """Test that `MSLE` Metric computes correct values.""" + y_true = y_true.astype(y_pred.dtype) + y_pred = y_pred.astype(y_true.dtype) + if sample_weights is None: + sample_weights = np.ones_like(y_true) + + metric = None + for labels, logits, weights in zip(y_true, y_pred, sample_weights): + update = metrax.MSLE.from_model_output( + predictions=logits, + labels=labels, + sample_weights=weights, + ) + metric = update if metric is None else metric.merge(update) + + expected = sklearn_metrics.mean_squared_log_error( + y_true.astype('float32').flatten(), + y_pred.astype('float32').flatten(), + sample_weight=sample_weights.astype('float32').flatten(), + ) + # Use lower tolerance for lower precision dtypes. + rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05 + atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05 + np.testing.assert_allclose( + metric.compute(), + expected, + rtol=rtol, + atol=atol, + ) + + @parameterized.named_parameters( + ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None), + ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None), + ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, None), + ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None), + ('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS), + ('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS), + ('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS), + ) + def test_rmsle(self, y_true, y_pred, sample_weights): + """Test that `RMSLE` Metric computes correct values.""" + y_true = y_true.astype(y_pred.dtype) + y_pred = y_pred.astype(y_true.dtype) + if sample_weights is None: + sample_weights = np.ones_like(y_true) + + metric = None + for labels, logits, weights in zip(y_true, y_pred, sample_weights): + update = metrax.RMSLE.from_model_output( + predictions=logits, + labels=labels, + sample_weights=weights, + ) + metric = update if metric is None else metric.merge(update) + + expected = sklearn_metrics.root_mean_squared_log_error( + y_true.astype('float32').flatten(), + y_pred.astype('float32').flatten(), + sample_weight=sample_weights.astype('float32').flatten(), + ) + # Use lower tolerance for lower precision dtypes. + rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05 + atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05 + np.testing.assert_allclose( + metric.compute(), + expected, + rtol=rtol, + atol=atol, + ) + @parameterized.named_parameters( ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None), ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),