Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/metrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Average = base.Average
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
BLEU = nlp_metrics.BLEU
CosineSimilarity = image_metrics.CosineSimilarity
DCGAtK = ranking_metrics.DCGAtK
Dice = image_metrics.Dice
FBetaScore = classification_metrics.FBetaScore
Expand Down Expand Up @@ -57,6 +58,7 @@
"Average",
"AveragePrecisionAtK",
"BLEU",
"CosineSimilarity",
"DCGAtK",
"Dice",
"FBetaScore",
Expand Down
40 changes: 40 additions & 0 deletions src/metrax/image_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,43 @@ def compute(self) -> jax.Array:
"""Returns the final Dice coefficient."""
epsilon = 1e-7
return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon)


@flax.struct.dataclass
class CosineSimilarity(base.Average):
r"""Calculates the Cosine Similarity between two arrays.

The Cosine Similarity is defined as the dot product of the vectors divided
by the product of their magnitudes (norms).

.. math::
cos_{sim}(x,y) = \frac{x \cdot y}{||x|| * ||y||}
"""

@classmethod
def from_model_output(
cls,
predictions: jax.Array,
targets: jax.Array,
axis: int = -1,
) -> 'CosineSimilarity':
"""Creates a CosineSimilarity instance.

Args:
predictions: A floating point array of the predictions. The shape should
be (batch_size,).
targets: A floating point array of the targets. The shape should be
(batch_size,).
axis: The axis to compute the norm over.

Returns:
A `CosineSimilarity` instance.
"""
dot_product = (predictions * targets).sum(axis=axis)
predictions_norm = jnp.linalg.norm(predictions, ord=2, axis=axis)
targets_norm = jnp.linalg.norm(targets, ord=2, axis=axis)

cosine_similarity = dot_product / (predictions_norm * targets_norm)

return super().from_model_output(values=cosine_similarity)

41 changes: 41 additions & 0 deletions src/metrax/image_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,47 @@ def test_dice(self, y_true, y_pred):

np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)

@parameterized.named_parameters(
(
'cosine_similarity_basic_f32',
PREDS_1,
TARGETS_1,
),
(
'cosine_similarity_multichannel_norm',
PREDS_2,
TARGETS_2,
),
(
'cosine_similarity_uint8_range_single_channel',
PREDS_3,
TARGETS_3,
),
(
'cosine_similarity_identical_images',
PREDS_4,
TARGETS_4,
),
(
'cosine_similarity_large_batch',
PREDS_6,
TARGETS_6,
),
)
def test_cosine_similarity_against_keras(self, predictions, targets):
"""Test that CosineSimilarity computes expected values."""
predictions = jnp.array(predictions)
targets = jnp.array(targets)
keras_cosine_similarity_metric = keras.metrics.CosineSimilarity()
keras_cosine_similarity_metric.update_state(predictions, targets)
expected = keras_cosine_similarity_metric.result()

metric = metrax.CosineSimilarity.from_model_output(
predictions=predictions, targets=targets
)
result = metric.compute()

np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)

if __name__ == '__main__':
absltest.main()
8 changes: 8 additions & 0 deletions src/metrax/metrax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ class MetraxTest(parameterized.TestCase):
'ks': KS,
},
),
(
'cosinesimilarity',
metrax.CosineSimilarity,
{
'predictions': OUTPUT_LABELS,
'targets': OUTPUT_PREDS,
},
),
(
'dcgAtK',
metrax.DCGAtK,
Expand Down
2 changes: 2 additions & 0 deletions src/metrax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Average = nnx_metrics.Average
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
BLEU = nnx_metrics.BLEU
CosineSimilarity = nnx_metrics.CosineSimilarity
DCGAtK = nnx_metrics.DCGAtK
Dice = nnx_metrics.Dice
FBetaScore = nnx_metrics.FBetaScore
Expand Down Expand Up @@ -50,6 +51,7 @@
"Average",
"AveragePrecisionAtK",
"BLEU",
"CosineSimilarity",
"DCGAtK",
"Dice",
"FBetaScore",
Expand Down
7 changes: 7 additions & 0 deletions src/metrax/nnx/nnx_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def __init__(self):
super().__init__(metrax.BLEU)


class CosineSimilarity(NnxWrapper):
"""An NNX class for the Metrax metric CosineSimilarity."""

def __init__(self):
super().__init__(metrax.CosineSimilarity)


class DCGAtK(NnxWrapper):
"""An NNX class for the Metrax metric DCGAtK."""

Expand Down