diff --git a/jaxonnxruntime/onnx_ops/__init__.py b/jaxonnxruntime/onnx_ops/__init__.py index 6a0af7d..fc98398 100644 --- a/jaxonnxruntime/onnx_ops/__init__.py +++ b/jaxonnxruntime/onnx_ops/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import mean as mean diff --git a/jaxonnxruntime/onnx_ops/mean.py b/jaxonnxruntime/onnx_ops/mean.py new file mode 100644 index 0000000..60c10ab --- /dev/null +++ b/jaxonnxruntime/onnx_ops/mean.py @@ -0,0 +1,54 @@ +"""Define ONNX Mean operator.""" +# pylint: disable=unused-argument +# pylint: disable=g-explicit-length-test +import functools +from collections.abc import Callable, Sequence +from typing import Any + +import jax +from jax import numpy as jnp +from jaxonnxruntime.core import handler +from jaxonnxruntime.core import onnx_node +from jaxonnxruntime.onnx_ops import onnx_ops_utils + + +@handler.register_op("Mean") +class Mean(handler.Handler): + """Implementation of the ONNX Mean operator.""" + + @classmethod + def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any): + onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) + + @classmethod + def version_1(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_1 Mean op.""" + cls._prepare(node, inputs, onnx_mean) + return onnx_mean + + @classmethod + def version_6(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_6 Mean op.""" + cls._prepare(node, inputs, onnx_mean) + return onnx_mean + + @classmethod + def version_8(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_8 Mean op.""" + cls._prepare(node, inputs, onnx_mean) + return onnx_mean + + @classmethod + def version_13(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_13 Mean op.""" + cls._prepare(node, inputs, onnx_mean) + return onnx_mean + + +@functools.partial(jax.jit, static_argnames=()) +def onnx_mean(*input_args): + """Element-wise mean of input tensors.""" + # Stack all inputs and compute mean along the first axis + # This computes (A + B + C + ...) / N + stacked = jnp.stack(input_args, axis=0) + return jnp.mean(stacked, axis=0) diff --git a/tests/onnx_ops_test.py b/tests/onnx_ops_test.py index 33e21b8..4b0eec4 100644 --- a/tests/onnx_ops_test.py +++ b/tests/onnx_ops_test.py @@ -117,6 +117,7 @@ class NodeTest(absltest.TestCase): include_patterns.append('test_matmul_') include_patterns.append('test_max_') include_patterns.append('test_maxpool_') +include_patterns.append('test_mean_') include_patterns.append('test_min_') include_patterns.append('test_mul_') include_patterns.append('test_neg_') @@ -268,4 +269,4 @@ class NodeTest(absltest.TestCase): if __name__ == '__main__': - absltest.main() + absltest.main() \ No newline at end of file