diff --git a/src/metrax/__init__.py b/src/metrax/__init__.py index 488b738..eb85737 100644 --- a/src/metrax/__init__.py +++ b/src/metrax/__init__.py @@ -26,6 +26,7 @@ Average = base.Average AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK BLEU = nlp_metrics.BLEU +CosineSimilarity = image_metrics.CosineSimilarity DCGAtK = ranking_metrics.DCGAtK Dice = image_metrics.Dice FBetaScore = classification_metrics.FBetaScore @@ -57,6 +58,7 @@ "Average", "AveragePrecisionAtK", "BLEU", + "CosineSimilarity", "DCGAtK", "Dice", "FBetaScore", diff --git a/src/metrax/image_metrics.py b/src/metrax/image_metrics.py index 96bb244..c19a75e 100644 --- a/src/metrax/image_metrics.py +++ b/src/metrax/image_metrics.py @@ -666,3 +666,43 @@ def compute(self) -> jax.Array: """Returns the final Dice coefficient.""" epsilon = 1e-7 return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon) + + +@flax.struct.dataclass +class CosineSimilarity(base.Average): + r"""Calculates the Cosine Similarity between two arrays. + + The Cosine Similarity is defined as the dot product of the vectors divided + by the product of their magnitudes (norms). + + .. math:: + cos_{sim}(x,y) = \frac{x \cdot y}{||x|| * ||y||} + """ + + @classmethod + def from_model_output( + cls, + predictions: jax.Array, + targets: jax.Array, + axis: int = -1, + ) -> 'CosineSimilarity': + """Creates a CosineSimilarity instance. + + Args: + predictions: A floating point array of the predictions. The shape should + be (batch_size,). + targets: A floating point array of the targets. The shape should be + (batch_size,). + axis: The axis to compute the norm over. + + Returns: + A `CosineSimilarity` instance. + """ + dot_product = (predictions * targets).sum(axis=axis) + predictions_norm = jnp.linalg.norm(predictions, ord=2, axis=axis) + targets_norm = jnp.linalg.norm(targets, ord=2, axis=axis) + + cosine_similarity = dot_product / (predictions_norm * targets_norm) + + return super().from_model_output(values=cosine_similarity) + diff --git a/src/metrax/image_metrics_test.py b/src/metrax/image_metrics_test.py index 1013990..a3f9fbf 100644 --- a/src/metrax/image_metrics_test.py +++ b/src/metrax/image_metrics_test.py @@ -553,6 +553,47 @@ def test_dice(self, y_true, y_pred): np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + @parameterized.named_parameters( + ( + 'cosine_similarity_basic_f32', + PREDS_1, + TARGETS_1, + ), + ( + 'cosine_similarity_multichannel_norm', + PREDS_2, + TARGETS_2, + ), + ( + 'cosine_similarity_uint8_range_single_channel', + PREDS_3, + TARGETS_3, + ), + ( + 'cosine_similarity_identical_images', + PREDS_4, + TARGETS_4, + ), + ( + 'cosine_similarity_large_batch', + PREDS_6, + TARGETS_6, + ), + ) + def test_cosine_similarity_against_keras(self, predictions, targets): + """Test that CosineSimilarity computes expected values.""" + predictions = jnp.array(predictions) + targets = jnp.array(targets) + keras_cosine_similarity_metric = keras.metrics.CosineSimilarity() + keras_cosine_similarity_metric.update_state(predictions, targets) + expected = keras_cosine_similarity_metric.result() + + metric = metrax.CosineSimilarity.from_model_output( + predictions=predictions, targets=targets + ) + result = metric.compute() + + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) if __name__ == '__main__': absltest.main() diff --git a/src/metrax/metrax_test.py b/src/metrax/metrax_test.py index fe94764..b6a5ce1 100644 --- a/src/metrax/metrax_test.py +++ b/src/metrax/metrax_test.py @@ -92,6 +92,14 @@ class MetraxTest(parameterized.TestCase): 'ks': KS, }, ), + ( + 'cosinesimilarity', + metrax.CosineSimilarity, + { + 'predictions': OUTPUT_LABELS, + 'targets': OUTPUT_PREDS, + }, + ), ( 'dcgAtK', metrax.DCGAtK, diff --git a/src/metrax/nnx/__init__.py b/src/metrax/nnx/__init__.py index 46053f7..07b40f2 100644 --- a/src/metrax/nnx/__init__.py +++ b/src/metrax/nnx/__init__.py @@ -20,6 +20,7 @@ Average = nnx_metrics.Average AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK BLEU = nnx_metrics.BLEU +CosineSimilarity = nnx_metrics.CosineSimilarity DCGAtK = nnx_metrics.DCGAtK Dice = nnx_metrics.Dice FBetaScore = nnx_metrics.FBetaScore @@ -50,6 +51,7 @@ "Average", "AveragePrecisionAtK", "BLEU", + "CosineSimilarity", "DCGAtK", "Dice", "FBetaScore", diff --git a/src/metrax/nnx/nnx_metrics.py b/src/metrax/nnx/nnx_metrics.py index 2461607..afdb97f 100644 --- a/src/metrax/nnx/nnx_metrics.py +++ b/src/metrax/nnx/nnx_metrics.py @@ -60,6 +60,13 @@ def __init__(self): super().__init__(metrax.BLEU) +class CosineSimilarity(NnxWrapper): + """An NNX class for the Metrax metric CosineSimilarity.""" + + def __init__(self): + super().__init__(metrax.CosineSimilarity) + + class DCGAtK(NnxWrapper): """An NNX class for the Metrax metric DCGAtK."""