pyproject.toml (8 changes: 4 additions & 4 deletions)
@@ -10,7 +10,7 @@ authors = [
 ]
 description = "A centralized JAX metrics library."
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.11"
 classifiers = [
     "Programming Language :: Python :: 3",
     "Operating System :: OS Independent",
@@ -43,8 +43,8 @@ Issues = "https://github.com/google/metrax/issues"
 dev = [
     "absl-py>=2.3.1",
     "jax[cpu]==0.6.2",
-    "jax_tpu_embedding==0.1.0.dev20250618",
-    "keras-hub",
+    "jax_tpu_embedding==0.1.0.dev20250618; platform_system != 'Darwin' and platform_system != 'Windows'",
+    "keras-hub;python_version<'3.13' and platform_system != 'Windows'",
     "keras-rs>=0.2.1",
     "nltk>=3.9.1",
     "pytest>=8.4.1",
@@ -58,7 +58,7 @@ dev = [
 docs = [
     "myst_nb",
     "sphinx-book-theme",
-    "scikit-learn==1.6.1",
+    "scikit-learn>=1.6.1",
 ]

 [tool.ruff]
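Note: the two new dependency specifiers use standard PEP 508 environment markers, so pip skips jax_tpu_embedding on macOS and Windows, and skips keras-hub on Windows and on Python 3.13+ (the skip reason in nlp_metrics_test.py below says keras_hub is not available there). A minimal sketch of how such a marker evaluates, using the third-party packaging library (the marker string is copied from the diff; the surrounding code is illustrative only):

    from packaging.markers import Marker

    # pip performs an equivalent check at install time before deciding
    # whether to install the guarded dependency.
    marker = Marker("platform_system != 'Darwin' and platform_system != 'Windows'")
    print(marker.evaluate())  # True on Linux, False on macOS or Windows

The requires-python bump from ">=3.8" to ">=3.11" also matters for the rest of this diff: zip(..., strict=...) only exists on Python 3.10+, so the test changes below are safe under the new floor.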
src/metrax/classification_metrics_test.py (20 changes: 10 additions & 10 deletions)
@@ -99,11 +99,11 @@ def test_accuracy(self, y_true, y_pred, sample_weights):
     sample_weights = np.ones_like(y_true)
     metrax_accuracy = metrax.Accuracy.empty()
     keras_accuracy = keras.metrics.Accuracy()
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.Accuracy.from_model_output(
-          predictions=logits,
-          labels=labels,
-          sample_weights=weights,
+          predictions=logits,
+          labels=labels,
+          sample_weights=weights,
       )
       metrax_accuracy = metrax_accuracy.merge(update)
       keras_accuracy.update_state(labels, logits, weights)
@@ -139,7 +139,7 @@ def test_precision(self, y_true, y_pred, threshold):
     expected = keras_precision.result()

     metric = None
-    for logits, labels in zip(y_pred, y_true):
+    for logits, labels in zip(y_pred, y_true, strict=False):
       update = metrax.Precision.from_model_output(
           predictions=logits,
           labels=labels,
@@ -178,7 +178,7 @@ def test_recall(self, y_true, y_pred, threshold):
     expected = keras_recall.result()

     metric = None
-    for logits, labels in zip(y_pred, y_true):
+    for logits, labels in zip(y_pred, y_true, strict=False):
       update = metrax.Recall.from_model_output(
           predictions=logits,
           labels=labels,
@@ -212,7 +212,7 @@ def test_aucpr(self, inputs, dtype):
     sample_weights = np.ones_like(y_true)

     metric = None
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.AUCPR.from_model_output(
           predictions=logits,
           labels=labels,
@@ -221,7 +221,7 @@ def test_aucpr(self, inputs, dtype):
       metric = update if metric is None else metric.merge(update)

     keras_aucpr = keras.metrics.AUC(curve='PR')
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       keras_aucpr.update_state(labels, logits, sample_weight=weights)
     expected = keras_aucpr.result()
     np.testing.assert_allclose(
@@ -253,7 +253,7 @@ def test_aucroc(self, inputs, dtype):
     sample_weights = np.ones_like(y_true)

     metric = None
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.AUCROC.from_model_output(
           predictions=logits,
           labels=labels,
@@ -262,7 +262,7 @@ def test_aucroc(self, inputs, dtype):
       metric = update if metric is None else metric.merge(update)

     keras_aucroc = keras.metrics.AUC(curve='ROC')
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       keras_aucroc.update_state(labels, logits, sample_weight=weights)
     expected = keras_aucroc.result()
     np.testing.assert_allclose(
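Note: every change in this file adds strict=False to an existing zip call. PEP 618 (Python 3.10) added the strict flag, and linters such as Ruff's B905 rule flag zip calls that leave it implicit; passing strict=False keeps the historical truncate-to-shortest behavior while making the choice explicit. A small sketch of the difference, with made-up data:

    a = [1, 2, 3]
    b = ['x', 'y']

    # strict=False (the pre-3.10 behavior): silently truncate to the shorter input.
    print(list(zip(a, b, strict=False)))  # [(1, 'x'), (2, 'y')]

    # strict=True: raise instead, catching silent data loss.
    try:
        list(zip(a, b, strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is shorter than argument 1

In these tests the zipped arrays always have matching first dimensions (sample_weights is built with np.ones_like(y_true)), so strict=False documents the current behavior rather than changing it.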
src/metrax/nlp_metrics.py (6 changes: 3 additions & 3 deletions)
@@ -34,7 +34,7 @@ def _get_single_n_grams(segment: list[str], order: int):
   Returns:
     A collections.Counter mapping n-gram tuples to their counts.
   """
-  return collections.Counter(zip(*[segment[i:] for i in range(order)]))
+  return collections.Counter(zip(*[segment[i:] for i in range(order)], strict=False))


 def _get_ngrams(segment: list[str], max_order: int):
@@ -141,7 +141,7 @@ def from_model_output(
     pred_length = 0
     ref_length = 0

-    for pred, ref_list in zip(predictions, references):
+    for pred, ref_list in zip(predictions, references, strict=False):
       pred = pred.split()
       ref_list = [r.split() for r in ref_list]
       pred_length += len(pred)
@@ -383,7 +383,7 @@ def from_model_output(
     total_f1 = 0.0
     num_examples = 0.0

-    for pred_str, ref_str in zip(predictions, references):
+    for pred_str, ref_str in zip(predictions, references, strict=False):
       pred_tokens = pred_str.split()
       ref_tokens = ref_str.split()

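Note: in _get_single_n_grams the truncation is load-bearing rather than lint appeasement: the shifted slices segment[i:] have different lengths, and zip stopping at the shortest one is exactly what yields the len(segment) - order + 1 n-grams, so strict=False is required here (strict=True would raise for any order > 1). A standalone sketch of the same idiom (the function name is changed only so it runs outside the library):

    import collections

    def single_n_grams(segment: list[str], order: int) -> collections.Counter:
      # segment[0:], segment[1:], ..., segment[order-1:] get shorter by one
      # element each; zip truncates to the shortest, so each window of
      # `order` consecutive tokens appears exactly once.
      return collections.Counter(
          zip(*[segment[i:] for i in range(order)], strict=False)
      )

    print(single_n_grams(['the', 'cat', 'sat'], 2))
    # Counter({('the', 'cat'): 1, ('cat', 'sat'): 1})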
src/metrax/nlp_metrics_test.py (14 changes: 8 additions & 6 deletions)
@@ -23,12 +23,14 @@
 import keras_hub
 import metrax
 import numpy as np
+import sys
+import pytest

 np.random.seed(42)


+@pytest.mark.skipif(sys.version_info >= (3, 13), reason="keras_hub not available for Python 3.13+")
 class NlpMetricsTest(parameterized.TestCase):

   def test_bleu_empty(self):
     """Tests the `empty` method of the `BLEU` class."""
     m = metrax.BLEU.empty()
@@ -90,7 +92,7 @@ def test_bleu_merge(self):
     keras_metric = keras_hub.metrics.Bleu()
     keras_metric.update_state(references, predictions)
     metrax_metric = None
-    for ref_list, pred in zip(references, predictions):
+    for ref_list, pred in zip(references, predictions, strict=False):
       update = metrax.BLEU.from_model_output([pred], [ref_list])
       metrax_metric = (
           update if metrax_metric is None else metrax_metric.merge(update)
@@ -183,7 +185,7 @@ def test_rouge_merge(self, metrax_rouge, keras_rouge):
     keras_metric_array = jnp.stack(list(keras_metric.result().values()))

     metrax_metric = None
-    for ref, pred in zip(references, predictions):
+    for ref, pred in zip(references, predictions, strict=False):
       update = metrax_rouge.from_model_output([pred], [ref])
       metrax_metric = (
           update if metrax_metric is None else metrax_metric.merge(update)
@@ -249,7 +251,7 @@ def test_perplexity(self, y_true, y_pred, sample_weights, from_logits):
     """Test that `Perplexity` Metric computes correct values."""
     keras_metric = keras_hub.metrics.Perplexity(from_logits=from_logits)
     metrax_metric = None
-    for index, (labels, logits) in enumerate(zip(y_true, y_pred)):
+    for index, (labels, logits) in enumerate(zip(y_true, y_pred, strict=False)):
       weights = sample_weights[index] if sample_weights is not None else None
       keras_metric.update_state(labels, logits, sample_weight=weights)
       update = metrax.Perplexity.from_model_output(
@@ -287,7 +289,7 @@ def test_wer(self):

     metrax_token_metric = None
     keras_metric = keras_hub.metrics.EditDistance(normalize=True)
-    for pred, ref in zip(tokenized_preds, tokenized_refs):
+    for pred, ref in zip(tokenized_preds, tokenized_refs, strict=False):
       metrax_update = metrax.WER.from_model_output(pred,ref)
       keras_metric.update_state(ref, pred)
       metrax_token_metric = metrax_update if metrax_token_metric is None else metrax_token_metric.merge(metrax_update)
@@ -301,7 +303,7 @@
     )

     metrax_string_metric = None
-    for pred, ref in zip(string_preds, string_refs):
+    for pred, ref in zip(string_preds, string_refs, strict=False):
       update = metrax.WER.from_model_output(predictions=pred, references=ref)
       metrax_string_metric = update if metrax_string_metric is None else metrax_string_metric.merge(update)

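Note: rather than guarding individual tests, the skipif marker sits on the test class, so pytest still collects NlpMetricsTest but skips every test in it on Python 3.13+, mirroring the keras-hub environment marker added in pyproject.toml. A minimal sketch of the pattern (class and test names invented):

    import sys

    import pytest

    @pytest.mark.skipif(
        sys.version_info >= (3, 13),
        reason='keras_hub not available for Python 3.13+',
    )
    class SomeKerasHubDependentTest:

      def test_something(self):
        ...  # skipped entirely on 3.13+, runs normally elsewhere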
src/metrax/nnx/nnx_wrapper_test.py (2 changes: 1 addition & 1 deletion)
@@ -79,7 +79,7 @@ def test_metric_update_and_compute(self, y_true, y_pred, sample_weights):
     sample_weights = np.ones_like(y_true)

     nnx_metric = metrax.nnx.MSE()
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       nnx_metric.update(
           predictions=logits,
           labels=labels,
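Note: this file exercises the stateful wrapper API (the metrax.nnx.MSE object is mutated in place by update) rather than the immutable from_model_output/merge pattern used elsewhere in the diff; the zip change itself is identical. A rough sketch contrasting the two styles, with invented data; the compute method is an assumption based on the test name test_metric_update_and_compute:

    import numpy as np
    import metrax

    y_true = np.random.uniform(size=(3, 8))
    y_pred = np.random.uniform(size=(3, 8))

    # Stateful nnx style: one object, updated in place per batch.
    nnx_mse = metrax.nnx.MSE()
    for labels, logits in zip(y_true, y_pred, strict=False):
      nnx_mse.update(predictions=logits, labels=labels)
    print(nnx_mse.compute())

    # Functional style: immutable updates folded together with merge.
    mse = None
    for labels, logits in zip(y_true, y_pred, strict=False):
      update = metrax.MSE.from_model_output(predictions=logits, labels=labels)
      mse = update if mse is None else mse.merge(update)
    print(mse.compute())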
src/metrax/regression_metrics_test.py (18 changes: 9 additions & 9 deletions)
@@ -82,7 +82,7 @@ def compute_metric(logits, labels):

     # 5. Verify against reference (Keras reference remains the same).
     keras_r2 = keras.metrics.R2Score()
-    for labels, logits in zip(OUTPUT_LABELS, OUTPUT_PREDS):
+    for labels, logits in zip(OUTPUT_LABELS, OUTPUT_PREDS, strict=False):
       keras_r2.update_state(
           labels[:, jnp.newaxis],
           logits[:, jnp.newaxis],
@@ -128,7 +128,7 @@ def sharded_r2(logits, labels):
     metric = metric_on_host.reduce()

     keras_r2 = keras.metrics.R2Score()
-    for labels, logits in zip(y_true, y_pred):
+    for labels, logits in zip(y_true, y_pred, strict=False):
       keras_r2.update_state(
           labels[:, jnp.newaxis],
           logits[:, jnp.newaxis],
@@ -178,7 +178,7 @@ def test_mae(self, y_true, y_pred, sample_weights):
     sample_weights = np.ones_like(y_true)

     metric = None
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.MAE.from_model_output(
           predictions=logits,
           labels=labels,
@@ -220,7 +220,7 @@ def test_mse(self, y_true, y_pred, sample_weights):
     sample_weights = np.ones_like(y_true)

     metric = None
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.MSE.from_model_output(
           predictions=logits,
           labels=labels,
@@ -263,7 +263,7 @@ def test_rmse(self, y_true, y_pred, sample_weights):

     metric = None
     keras_rmse = keras.metrics.RootMeanSquaredError()
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.RMSE.from_model_output(
           predictions=logits,
           labels=labels,
@@ -299,7 +299,7 @@ def test_msle(self, y_true, y_pred, sample_weights):
     sample_weights = np.ones_like(y_true)

     metric = None
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.MSLE.from_model_output(
           predictions=logits,
           labels=labels,
@@ -339,7 +339,7 @@ def test_rmsle(self, y_true, y_pred, sample_weights):
     sample_weights = np.ones_like(y_true)

     metric = None
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.RMSLE.from_model_output(
           predictions=logits,
           labels=labels,
@@ -380,7 +380,7 @@ def test_rsquared(self, y_true, y_pred, sample_weights):

     metric = None
     keras_r2 = keras.metrics.R2Score()
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.RSQUARED.from_model_output(
           predictions=logits,
           labels=labels,
@@ -416,7 +416,7 @@ def test_spearman(self, y_true, y_pred):
     y_pred = y_pred.astype(y_true.dtype)

     metric = None
-    for labels, logits in zip(y_true, y_pred):
+    for labels, logits in zip(y_true, y_pred, strict=False):
       update = metrax.SpearmanRankCorrelation.from_model_output(
           predictions=logits,
           labels=labels,
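Note: every test in this file follows the same recipe that the hunks above show piecemeal: accumulate a metrax metric batch by batch, accumulate the matching Keras metric, then compare the results. A condensed sketch of that recipe (the data, tolerance, and the CLU-style compute method are assumptions based on the calls visible in this diff):

    import keras
    import metrax
    import numpy as np

    y_true = np.random.uniform(size=(4, 16)).astype(np.float32)
    y_pred = np.random.uniform(size=(4, 16)).astype(np.float32)

    metric = None
    keras_rmse = keras.metrics.RootMeanSquaredError()
    for labels, logits in zip(y_true, y_pred, strict=False):
      update = metrax.RMSE.from_model_output(predictions=logits, labels=labels)
      metric = update if metric is None else metric.merge(update)
      keras_rmse.update_state(labels, logits)

    np.testing.assert_allclose(metric.compute(), keras_rmse.result(), rtol=1e-5)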