From 83f3b461d314b90b1b3fa0d6d1a5955e7d62ac8d Mon Sep 17 00:00:00 2001 From: htjb Date: Wed, 5 Nov 2025 12:44:52 +0000 Subject: [PATCH 1/4] implementation of the mean opp --- jaxonnxruntime/onnx_ops/mean.py | 51 +++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 jaxonnxruntime/onnx_ops/mean.py diff --git a/jaxonnxruntime/onnx_ops/mean.py b/jaxonnxruntime/onnx_ops/mean.py new file mode 100644 index 0000000..bc4aa0c --- /dev/null +++ b/jaxonnxruntime/onnx_ops/mean.py @@ -0,0 +1,51 @@ +"""Define ONNX Mean operator.""" +# pylint: disable=unused-argument +# pylint: disable=g-explicit-length-test +import functools +from collections.abc import Callable, Sequence +from typing import Any + +import jax +from jax import numpy as jnp +from jaxonnxruntime.core import handler +from jaxonnxruntime.core import onnx_node +from jaxonnxruntime.onnx_ops import onnx_ops_utils + + +@handler.register_op("Mean") +class Mean(handler.Handler): + """Implementation of the ONNX Mean operator.""" + + @classmethod + def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any): + onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) + + @classmethod + def version_1(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_1 Mean op.""" + cls._prepare(node, inputs, onnx_mean) + return onnx_mean + + @classmethod + def version_6(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_6 Mean op.""" + cls._prepare(node, inputs, onnx_mean) + return onnx_mean + + @classmethod + def version_8(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_8 Mean op.""" + cls._prepare(node, inputs, onnx_mean) + return onnx_mean + + @classmethod + def version_13(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_13 Mean op.""" + cls._prepare(node, inputs, onnx_mean) + return onnx_mean + + +@functools.partial(jax.jit, static_argnames=()) +def onnx_mean(*input_args): + stacked = jnp.stack(input_args, axis=0) + return jnp.mean(stacked, axis=0) From e4ebf43d1b90502646739bad6a002faabb814a31 Mon Sep 17 00:00:00 2001 From: htjb Date: Wed, 5 Nov 2025 12:46:00 +0000 Subject: [PATCH 2/4] adding mean as mean --- jaxonnxruntime/onnx_ops/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxonnxruntime/onnx_ops/__init__.py b/jaxonnxruntime/onnx_ops/__init__.py index 6a0af7d..fc98398 100644 --- a/jaxonnxruntime/onnx_ops/__init__.py +++ b/jaxonnxruntime/onnx_ops/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import mean as mean From 76334f62028f926cb09fb4614cfd4c2b00d33bfe Mon Sep 17 00:00:00 2001 From: htjb Date: Wed, 5 Nov 2025 12:46:15 +0000 Subject: [PATCH 3/4] pointing at the tests for mean --- tests/onnx_ops_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/onnx_ops_test.py b/tests/onnx_ops_test.py index 33e21b8..4b0eec4 100644 --- a/tests/onnx_ops_test.py +++ b/tests/onnx_ops_test.py @@ -117,6 +117,7 @@ class NodeTest(absltest.TestCase): include_patterns.append('test_matmul_') include_patterns.append('test_max_') include_patterns.append('test_maxpool_') +include_patterns.append('test_mean_') include_patterns.append('test_min_') include_patterns.append('test_mul_') include_patterns.append('test_neg_') @@ -268,4 +269,4 @@ class NodeTest(absltest.TestCase): if __name__ == '__main__': - absltest.main() + absltest.main() \ No newline at end of file From 7ba5079160840aebb82d06b5ef13f112555a3506 Mon Sep 17 00:00:00 2001 From: htjb Date: Wed, 5 Nov 2025 12:48:19 +0000 Subject: [PATCH 4/4] adding a comment to function --- jaxonnxruntime/onnx_ops/mean.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jaxonnxruntime/onnx_ops/mean.py b/jaxonnxruntime/onnx_ops/mean.py index bc4aa0c..60c10ab 100644 --- a/jaxonnxruntime/onnx_ops/mean.py +++ b/jaxonnxruntime/onnx_ops/mean.py @@ -47,5 +47,8 @@ def version_13(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable @functools.partial(jax.jit, static_argnames=()) def onnx_mean(*input_args): + """Element-wise mean of input tensors.""" + # Stack all inputs and compute mean along the first axis + # This computes (A + B + C + ...) / N stacked = jnp.stack(input_args, axis=0) return jnp.mean(stacked, axis=0)