From eeaadec8f9c6222d4986314c3ee826281177e411 Mon Sep 17 00:00:00 2001
From: Nikola Savic
Date: Sat, 31 Jan 2026 08:08:51 +0100
Subject: [PATCH 1/3] Update pyproject.toml for Python 3.11+ and platform-specific dependencies

---
 pyproject.toml | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0bf5e95..16afec6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ authors = [
 ]
 description = "A centralized JAX metrics library."
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.11"
 classifiers = [
     "Programming Language :: Python :: 3",
     "Operating System :: OS Independent",
@@ -43,8 +43,8 @@ Issues = "https://github.com/google/metrax/issues"
 dev = [
     "absl-py>=2.3.1",
     "jax[cpu]==0.6.2",
-    "jax_tpu_embedding==0.1.0.dev20250618",
-    "keras-hub",
+    "jax_tpu_embedding==0.1.0.dev20250618; platform_system != 'Darwin' and platform_system != 'Windows'",
+    "keras-hub; python_version < '3.13' and platform_system != 'Windows'",
     "keras-rs>=0.2.1",
     "nltk>=3.9.1",
     "pytest>=8.4.1",
@@ -58,7 +58,7 @@ dev = [
 docs = [
     "myst_nb",
     "sphinx-book-theme",
-    "scikit-learn==1.6.1",
+    "scikit-learn>=1.6.1",
 ]
 
 [tool.ruff]

From 1f0bab9ecad0c3ad2276e5ce50aee7815ed743e0 Mon Sep 17 00:00:00 2001
From: Nikola Savic
Date: Sat, 31 Jan 2026 08:09:22 +0100
Subject: [PATCH 2/3] Add pytest skip for keras_hub on Python 3.13+

---
 src/metrax/nlp_metrics_test.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/metrax/nlp_metrics_test.py b/src/metrax/nlp_metrics_test.py
index 8152553..ab156de 100644
--- a/src/metrax/nlp_metrics_test.py
+++ b/src/metrax/nlp_metrics_test.py
@@ -23,12 +23,14 @@
 import keras_hub
 import metrax
 import numpy as np
+import sys
+import pytest
 
 np.random.seed(42)
 
 
+@pytest.mark.skipif(sys.version_info >= (3, 13), reason="keras_hub not available for Python 3.13+")
 class NlpMetricsTest(parameterized.TestCase):
-
   def test_bleu_empty(self):
     """Tests the `empty` method of the `BLEU` class."""
     m = metrax.BLEU.empty()

From 5450d5d198752763666b1d9a54608e552f3e4177 Mon Sep 17 00:00:00 2001
From: Nikola Savic
Date: Sat, 31 Jan 2026 08:40:33 +0100
Subject: [PATCH 3/3] Add explicit strict=False to all zip() calls to address ruff warnings

strict=False keeps the existing runtime behavior. Notably, in
_get_single_n_grams the zipped slices are intentionally unequal in
length, so strict=True would raise for any n-gram order above 1.

---
 src/metrax/classification_metrics_test.py | 20 ++++++++++----------
 src/metrax/nlp_metrics.py                 |  6 +++---
 src/metrax/nlp_metrics_test.py            | 10 +++++-----
 src/metrax/nnx/nnx_wrapper_test.py        |  2 +-
 src/metrax/regression_metrics_test.py     | 18 +++++++++---------
 5 files changed, 28 insertions(+), 28 deletions(-)

diff --git a/src/metrax/classification_metrics_test.py b/src/metrax/classification_metrics_test.py
index 2e54451..c4a8f46 100644
--- a/src/metrax/classification_metrics_test.py
+++ b/src/metrax/classification_metrics_test.py
@@ -99,11 +99,11 @@ def test_accuracy(self, y_true, y_pred, sample_weights):
       sample_weights = np.ones_like(y_true)
     metrax_accuracy = metrax.Accuracy.empty()
     keras_accuracy = keras.metrics.Accuracy()
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.Accuracy.from_model_output(
-          predictions=logits,
-          labels=labels,
-          sample_weights=weights,
+          predictions=logits,
+          labels=labels,
+          sample_weights=weights,
       )
       metrax_accuracy = metrax_accuracy.merge(update)
       keras_accuracy.update_state(labels, logits, weights)
@@ -139,7 +139,7 @@ def test_precision(self, y_true, y_pred, threshold):
     expected = keras_precision.result()
 
     metric = None
-    for logits, labels in zip(y_pred, y_true):
+    for logits, labels in zip(y_pred, y_true, strict=False):
       update = metrax.Precision.from_model_output(
           predictions=logits,
           labels=labels,
@@ -178,7 +178,7 @@ def test_recall(self, y_true, y_pred, threshold):
     expected = keras_recall.result()
 
     metric = None
-    for logits, labels in zip(y_pred, y_true):
+    for logits, labels in zip(y_pred, y_true, strict=False):
       update = metrax.Recall.from_model_output(
           predictions=logits,
           labels=labels,
@@ -212,7 +212,7 @@ def test_aucpr(self, inputs, dtype):
       sample_weights = np.ones_like(y_true)
 
     metric = None
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.AUCPR.from_model_output(
           predictions=logits,
           labels=labels,
@@ -221,7 +221,7 @@ def test_aucpr(self, inputs, dtype):
       metric = update if metric is None else metric.merge(update)
 
     keras_aucpr = keras.metrics.AUC(curve='PR')
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       keras_aucpr.update_state(labels, logits, sample_weight=weights)
     expected = keras_aucpr.result()
     np.testing.assert_allclose(
@@ -253,7 +253,7 @@ def test_aucroc(self, inputs, dtype):
       sample_weights = np.ones_like(y_true)
 
     metric = None
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       update = metrax.AUCROC.from_model_output(
           predictions=logits,
           labels=labels,
@@ -262,7 +262,7 @@ def test_aucroc(self, inputs, dtype):
       metric = update if metric is None else metric.merge(update)
 
     keras_aucroc = keras.metrics.AUC(curve='ROC')
-    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
+    for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False):
       keras_aucroc.update_state(labels, logits, sample_weight=weights)
     expected = keras_aucroc.result()
     np.testing.assert_allclose(

diff --git a/src/metrax/nlp_metrics.py b/src/metrax/nlp_metrics.py
index 4ce504c..7ff60fa 100644
--- a/src/metrax/nlp_metrics.py
+++ b/src/metrax/nlp_metrics.py
@@ -34,7 +34,7 @@ def _get_single_n_grams(segment: list[str], order: int):
   Returns:
     A collections.Counter mapping n-gram tuples to their counts.
""" - return collections.Counter(zip(*[segment[i:] for i in range(order)])) + return collections.Counter(zip(*[segment[i:] for i in range(order)], strict=False)) def _get_ngrams(segment: list[str], max_order: int): @@ -141,7 +141,7 @@ def from_model_output( pred_length = 0 ref_length = 0 - for pred, ref_list in zip(predictions, references): + for pred, ref_list in zip(predictions, references, strict=False): pred = pred.split() ref_list = [r.split() for r in ref_list] pred_length += len(pred) @@ -383,7 +383,7 @@ def from_model_output( total_f1 = 0.0 num_examples = 0.0 - for pred_str, ref_str in zip(predictions, references): + for pred_str, ref_str in zip(predictions, references, strict=False): pred_tokens = pred_str.split() ref_tokens = ref_str.split() diff --git a/src/metrax/nlp_metrics_test.py b/src/metrax/nlp_metrics_test.py index ab156de..34e25b8 100644 --- a/src/metrax/nlp_metrics_test.py +++ b/src/metrax/nlp_metrics_test.py @@ -92,7 +92,7 @@ def test_bleu_merge(self): keras_metric = keras_hub.metrics.Bleu() keras_metric.update_state(references, predictions) metrax_metric = None - for ref_list, pred in zip(references, predictions): + for ref_list, pred in zip(references, predictions, strict=False): update = metrax.BLEU.from_model_output([pred], [ref_list]) metrax_metric = ( update if metrax_metric is None else metrax_metric.merge(update) @@ -185,7 +185,7 @@ def test_rouge_merge(self, metrax_rouge, keras_rouge): keras_metric_array = jnp.stack(list(keras_metric.result().values())) metrax_metric = None - for ref, pred in zip(references, predictions): + for ref, pred in zip(references, predictions, strict=False): update = metrax_rouge.from_model_output([pred], [ref]) metrax_metric = ( update if metrax_metric is None else metrax_metric.merge(update) @@ -251,7 +251,7 @@ def test_perplexity(self, y_true, y_pred, sample_weights, from_logits): """Test that `Perplexity` Metric computes correct values.""" keras_metric = keras_hub.metrics.Perplexity(from_logits=from_logits) metrax_metric = None - for index, (labels, logits) in enumerate(zip(y_true, y_pred)): + for index, (labels, logits) in enumerate(zip(y_true, y_pred, strict=False)): weights = sample_weights[index] if sample_weights is not None else None keras_metric.update_state(labels, logits, sample_weight=weights) update = metrax.Perplexity.from_model_output( @@ -289,7 +289,7 @@ def test_wer(self): metrax_token_metric = None keras_metric = keras_hub.metrics.EditDistance(normalize=True) - for pred, ref in zip(tokenized_preds, tokenized_refs): + for pred, ref in zip(tokenized_preds, tokenized_refs, strict=False): metrax_update = metrax.WER.from_model_output(pred,ref) keras_metric.update_state(ref, pred) metrax_token_metric = metrax_update if metrax_token_metric is None else metrax_token_metric.merge(metrax_update) @@ -303,7 +303,7 @@ def test_wer(self): ) metrax_string_metric = None - for pred, ref in zip(string_preds, string_refs): + for pred, ref in zip(string_preds, string_refs, strict=False): update = metrax.WER.from_model_output(predictions=pred, references=ref) metrax_string_metric = update if metrax_string_metric is None else metrax_string_metric.merge(update) diff --git a/src/metrax/nnx/nnx_wrapper_test.py b/src/metrax/nnx/nnx_wrapper_test.py index d5a48b5..b7a8746 100644 --- a/src/metrax/nnx/nnx_wrapper_test.py +++ b/src/metrax/nnx/nnx_wrapper_test.py @@ -79,7 +79,7 @@ def test_metric_update_and_compute(self, y_true, y_pred, sample_weights): sample_weights = np.ones_like(y_true) nnx_metric = metrax.nnx.MSE() - for labels, logits, 
weights in zip(y_true, y_pred, sample_weights): + for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False): nnx_metric.update( predictions=logits, labels=labels, diff --git a/src/metrax/regression_metrics_test.py b/src/metrax/regression_metrics_test.py index cb75fda..adc2f25 100644 --- a/src/metrax/regression_metrics_test.py +++ b/src/metrax/regression_metrics_test.py @@ -82,7 +82,7 @@ def compute_metric(logits, labels): # 5. Verify against reference (Keras reference remains the same). keras_r2 = keras.metrics.R2Score() - for labels, logits in zip(OUTPUT_LABELS, OUTPUT_PREDS): + for labels, logits in zip(OUTPUT_LABELS, OUTPUT_PREDS, strict=False): keras_r2.update_state( labels[:, jnp.newaxis], logits[:, jnp.newaxis], @@ -128,7 +128,7 @@ def sharded_r2(logits, labels): metric = metric_on_host.reduce() keras_r2 = keras.metrics.R2Score() - for labels, logits in zip(y_true, y_pred): + for labels, logits in zip(y_true, y_pred, strict=False): keras_r2.update_state( labels[:, jnp.newaxis], logits[:, jnp.newaxis], @@ -178,7 +178,7 @@ def test_mae(self, y_true, y_pred, sample_weights): sample_weights = np.ones_like(y_true) metric = None - for labels, logits, weights in zip(y_true, y_pred, sample_weights): + for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False): update = metrax.MAE.from_model_output( predictions=logits, labels=labels, @@ -220,7 +220,7 @@ def test_mse(self, y_true, y_pred, sample_weights): sample_weights = np.ones_like(y_true) metric = None - for labels, logits, weights in zip(y_true, y_pred, sample_weights): + for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False): update = metrax.MSE.from_model_output( predictions=logits, labels=labels, @@ -263,7 +263,7 @@ def test_rmse(self, y_true, y_pred, sample_weights): metric = None keras_rmse = keras.metrics.RootMeanSquaredError() - for labels, logits, weights in zip(y_true, y_pred, sample_weights): + for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False): update = metrax.RMSE.from_model_output( predictions=logits, labels=labels, @@ -299,7 +299,7 @@ def test_msle(self, y_true, y_pred, sample_weights): sample_weights = np.ones_like(y_true) metric = None - for labels, logits, weights in zip(y_true, y_pred, sample_weights): + for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False): update = metrax.MSLE.from_model_output( predictions=logits, labels=labels, @@ -339,7 +339,7 @@ def test_rmsle(self, y_true, y_pred, sample_weights): sample_weights = np.ones_like(y_true) metric = None - for labels, logits, weights in zip(y_true, y_pred, sample_weights): + for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False): update = metrax.RMSLE.from_model_output( predictions=logits, labels=labels, @@ -380,7 +380,7 @@ def test_rsquared(self, y_true, y_pred, sample_weights): metric = None keras_r2 = keras.metrics.R2Score() - for labels, logits, weights in zip(y_true, y_pred, sample_weights): + for labels, logits, weights in zip(y_true, y_pred, sample_weights, strict=False): update = metrax.RSQUARED.from_model_output( predictions=logits, labels=labels, @@ -416,7 +416,7 @@ def test_spearman(self, y_true, y_pred): y_pred = y_pred.astype(y_true.dtype) metric = None - for labels, logits in zip(y_true, y_pred): + for labels, logits in zip(y_true, y_pred, strict=False): update = metrax.SpearmanRankCorrelation.from_model_output( predictions=logits, labels=labels,
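
For reference on the strict flag itself: zip(strict=...) exists since Python 3.10, and ruff's B905 (zip-without-explicit-strict) only asks that the choice be spelled out. The one call site in this series where strict=False is load-bearing rather than cosmetic is the n-gram helper in nlp_metrics.py, since the shifted slices it zips shrink by one element each. A minimal standalone sketch of that behavior (plain Python, no metrax imports; the helper name and sample segment are illustrative only):

    import collections

    def single_n_grams(segment: list[str], order: int):
        # segment[0:], segment[1:], ... have decreasing lengths, so zip()
        # must be allowed to stop at the shortest slice; strict=True would
        # raise ValueError here for any order > 1.
        return collections.Counter(
            zip(*[segment[i:] for i in range(order)], strict=False)
        )

    print(single_n_grams("the cat sat".split(), 2))
    # Counter({('the', 'cat'): 1, ('cat', 'sat'): 1})

In the test-suite loops, by contrast, the zipped arrays share a leading batch dimension, so strict=False merely preserves the previous zip() behavior while silencing the warning.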