Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jaxonnxruntime/onnx_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import mean as mean
54 changes: 54 additions & 0 deletions jaxonnxruntime/onnx_ops/mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Define ONNX Mean operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op("Mean")
class Mean(handler.Handler):
"""Implementation of the ONNX Mean operator."""

@classmethod
def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any):
onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl)

@classmethod
def version_1(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_1 Mean op."""
cls._prepare(node, inputs, onnx_mean)
return onnx_mean

@classmethod
def version_6(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_6 Mean op."""
cls._prepare(node, inputs, onnx_mean)
return onnx_mean

@classmethod
def version_8(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_8 Mean op."""
cls._prepare(node, inputs, onnx_mean)
return onnx_mean

@classmethod
def version_13(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_13 Mean op."""
cls._prepare(node, inputs, onnx_mean)
return onnx_mean


@functools.partial(jax.jit, static_argnames=())
def onnx_mean(*input_args):
"""Element-wise mean of input tensors."""
# Stack all inputs and compute mean along the first axis
# This computes (A + B + C + ...) / N
stacked = jnp.stack(input_args, axis=0)
return jnp.mean(stacked, axis=0)
3 changes: 2 additions & 1 deletion tests/onnx_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class NodeTest(absltest.TestCase):
include_patterns.append('test_matmul_')
include_patterns.append('test_max_')
include_patterns.append('test_maxpool_')
include_patterns.append('test_mean_')
include_patterns.append('test_min_')
include_patterns.append('test_mul_')
include_patterns.append('test_neg_')
Expand Down Expand Up @@ -268,4 +269,4 @@ class NodeTest(absltest.TestCase):


if __name__ == '__main__':
absltest.main()
absltest.main()