Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
jeffcarp
left a comment
There was a problem hiding this comment.
Thanks for the contribution Pavan! Left some feedback.
| """ | ||
|
|
||
| if from_logits: | ||
| predictions = jax.nn.softmax(predictions, axis=-1) |
There was a problem hiding this comment.
This is inconsistent with the other metrics where _convert_logits_to_probabilities is called
| ValueError: If type of `labels` is wrong or the shapes of `predictions` | ||
| and `labels` are incompatible. | ||
| """ | ||
| predictions = _convert_logits_to_probabilities(predictions, from_logits) |
There was a problem hiding this comment.
Can you update these so it only calls _convert_logits_to_probabilities if from_logits is true?
| @classmethod | ||
| def from_model_output( | ||
| cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5 | ||
| cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5, from_logits: bool = False |
There was a problem hiding this comment.
Can you update the docstrings (here and below)?
| raise ValueError('The "Threshold" value must be between 0 and 1.') | ||
|
|
||
| # If the predictions are logits, convert them to probabilities | ||
|
|
| ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5), | ||
| ) | ||
| def test_precision(self, y_true, y_pred, threshold): | ||
| def test_precision(self, y_true, y_pred, threshold,from_logits=False): |
| ), | ||
| ) | ||
| def test_aucpr(self, inputs, dtype): | ||
| def test_aucpr(self, inputs, dtype, from_logits=False): |
There was a problem hiding this comment.
Remove default here as well
| def test_aucpr(self, inputs, dtype, from_logits=False): | ||
| """Test that `AUC-PR` Metric computes correct values.""" | ||
| y_true, y_pred, sample_weights = inputs | ||
| y_true, y_pred, sample_weights, from_logits = inputs |
There was a problem hiding this comment.
This is shadowing the from_logits variable - dedupe
|
|
||
| keras_aucpr = keras.metrics.AUC(curve='PR') | ||
| if from_logits: | ||
| y_pred = jax.nn.softmax(y_pred, axis=-1) |
There was a problem hiding this comment.
Please match 2 spaces per indentation style here and everywhere else
| def test_aucroc(self, inputs, dtype, from_logits=False): | ||
| """Test that `AUC-ROC` Metric computes correct values.""" | ||
| y_true, y_pred, sample_weights = inputs | ||
| y_true, y_pred, sample_weights,from_logits = inputs |
| ), | ||
| ) | ||
| def test_aucroc(self, inputs, dtype): | ||
| def test_aucroc(self, inputs, dtype, from_logits=False): |
This pull request introduces a new boolean parameter from_logits to Metrax classification metrics, enabling users to pass raw model logits directly without manually converting them to probabilities.
It directly resolves issue #105
Background:
Currently, users must apply softmax activation on logits before passing predictions to metrics, which adds boilerplate and can lead to errors.
What’s New:
Added from_logits flag to classification metrics: Precision, Recall, F1Score, FBetaScore, Accuracy.
When from_logits=True, Metrax automatically applies:
Softmax for multi-class logits
Backward compatibility preserved (from_logits=False by default).
Test suites updated to cover both activated and raw logits inputs.
Benefits:
Simplifies user workflow by removing manual activation step.
Reduces common user errors with logits processing.
Improves consistency and usability.
Tests:
Added parameterized tests for from_logits=True scenarios.
Verified numerical equivalence with ground truth metrics.
Passed tests across multiple dtypes and classification settings.