From fc13fd192647443ae8acdd2aa4682f38e6fb64ca Mon Sep 17 00:00:00 2001 From: Pravin-Dangol Date: Thu, 7 Aug 2025 16:21:06 -0700 Subject: [PATCH] Add ops for elu, hardsigmoid, shrink, sign, and tan. Add new ops Update default value in shrink Add ops for elu, hardsigmoid, shrink, sign, and tan. Update default value in shrink --- jaxonnxruntime/__init__.py | 5 +++ jaxonnxruntime/onnx_ops/elu.py | 48 ++++++++++++++++++++++++++ jaxonnxruntime/onnx_ops/hardsigmoid.py | 48 ++++++++++++++++++++++++++ jaxonnxruntime/onnx_ops/shrink.py | 40 +++++++++++++++++++++ jaxonnxruntime/onnx_ops/sign.py | 41 ++++++++++++++++++++++ jaxonnxruntime/onnx_ops/tan.py | 41 ++++++++++++++++++++++ tests/onnx_ops_test.py | 5 +++ 7 files changed, 228 insertions(+) create mode 100644 jaxonnxruntime/onnx_ops/elu.py create mode 100644 jaxonnxruntime/onnx_ops/hardsigmoid.py create mode 100644 jaxonnxruntime/onnx_ops/shrink.py create mode 100644 jaxonnxruntime/onnx_ops/sign.py create mode 100644 jaxonnxruntime/onnx_ops/tan.py diff --git a/jaxonnxruntime/__init__.py b/jaxonnxruntime/__init__.py index fe5349f..7fa74c1 100644 --- a/jaxonnxruntime/__init__.py +++ b/jaxonnxruntime/__init__.py @@ -50,6 +50,7 @@ from jaxonnxruntime.onnx_ops import div from jaxonnxruntime.onnx_ops import dropout from jaxonnxruntime.onnx_ops import einsum +from jaxonnxruntime.onnx_ops import elu from jaxonnxruntime.onnx_ops import equal from jaxonnxruntime.onnx_ops import erf from jaxonnxruntime.onnx_ops import exp @@ -61,6 +62,7 @@ from jaxonnxruntime.onnx_ops import globalaveragepool from jaxonnxruntime.onnx_ops import greater from jaxonnxruntime.onnx_ops import greaterorequal +from jaxonnxruntime.onnx_ops import hardsigmoid from jaxonnxruntime.onnx_ops import identity from jaxonnxruntime.onnx_ops import if_op from jaxonnxruntime.onnx_ops import leakyrelu @@ -94,7 +96,9 @@ from jaxonnxruntime.onnx_ops import scatternd from jaxonnxruntime.onnx_ops import selu from jaxonnxruntime.onnx_ops import shape +from 
"""Define ONNX Elu operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op('Elu')
class Elu(handler.Handler):
  """Implementation of the ONNX Elu operator."""

  @classmethod
  def _prepare(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
  ):
    # Elu's only attribute is `alpha`; the ONNX spec defaults it to 1.0.
    node.attrs_dict['alpha'] = node.attrs.get('alpha', 1.0)

  @classmethod
  def version_1(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_1 Elu op."""
    cls._prepare(node, inputs, onnx_elu)
    return onnx_elu

  @classmethod
  def version_6(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_6 Elu op."""
    cls._prepare(node, inputs, onnx_elu)
    return onnx_elu

  @classmethod
  def version_22(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_22 Elu op."""
    cls._prepare(node, inputs, onnx_elu)
    return onnx_elu


@functools.partial(jax.jit, static_argnames=())
def onnx_elu(*input_args, alpha):
  """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Elu for more details."""
  assert len(input_args) == 1
  data = input_args[0]
  # jax.nn.elu implements x if x > 0 else alpha * (exp(x) - 1), matching ONNX.
  return jax.nn.elu(data, alpha)
"""Define ONNX HardSigmoid operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op('HardSigmoid')
class HardSigmoid(handler.Handler):
  """Implementation of the ONNX HardSigmoid operator."""

  @classmethod
  def _prepare(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
  ):
    # ONNX spec defaults: alpha=0.2, beta=0.5 (same for all supported opsets).
    node.attrs_dict['alpha'] = node.attrs.get('alpha', 0.2)
    node.attrs_dict['beta'] = node.attrs.get('beta', 0.5)

  @classmethod
  def version_1(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_1 HardSigmoid op."""
    cls._prepare(node, inputs, onnx_hardsigmoid)
    return onnx_hardsigmoid

  @classmethod
  def version_6(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_6 HardSigmoid op."""
    cls._prepare(node, inputs, onnx_hardsigmoid)
    return onnx_hardsigmoid

  @classmethod
  def version_22(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_22 HardSigmoid op."""
    cls._prepare(node, inputs, onnx_hardsigmoid)
    return onnx_hardsigmoid


@functools.partial(jax.jit, static_argnames=())
def onnx_hardsigmoid(*input_args, alpha, beta):
  """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#HardSigmoid for more details."""
  assert len(input_args) == 1
  data = input_args[0]
  # y = max(0, min(1, alpha * x + beta)); cast back so the int/float dtype of
  # the input is preserved after promotion by the scalar arithmetic.
  saturated = jnp.minimum(1, data * alpha + beta)
  return jnp.maximum(0, saturated).astype(data.dtype)
"""Define ONNX Shrink operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op('Shrink')
class Shrink(handler.Handler):
  """Implementation of the ONNX Shrink operator."""

  @classmethod
  def _prepare(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
  ):
    # ONNX spec defaults: bias=0.0, lambd=0.5.
    node.attrs_dict['bias'] = node.attrs.get('bias', 0.0)
    node.attrs_dict['lambd'] = node.attrs.get('lambd', 0.5)

  @classmethod
  def version_9(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_9 Shrink op."""
    cls._prepare(node, inputs, onnx_shrink)
    return onnx_shrink


@functools.partial(jax.jit, static_argnames=())
def onnx_shrink(*input_args, bias, lambd):
  """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Shrink for more details."""
  assert len(input_args) == 1
  data = input_args[0]
  # Piecewise per the spec: x + bias below -lambd, x - bias above lambd,
  # and 0 inside [-lambd, lambd].
  above = jnp.where(data > lambd, data - bias, 0)
  return jnp.where(data < -lambd, data + bias, above).astype(data.dtype)
"""Define ONNX Sign operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op('Sign')
class Sign(handler.Handler):
  """Implementation of the ONNX Sign operator."""

  @classmethod
  def _prepare(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
  ):
    # Sign declares no attributes; the shared helper copies any node attrs
    # that match the impl's keyword parameters.
    onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl)

  @classmethod
  def version_9(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_9 Sign op."""
    cls._prepare(node, inputs, onnx_sign)
    return onnx_sign

  @classmethod
  def version_13(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_13 Sign op."""
    cls._prepare(node, inputs, onnx_sign)
    return onnx_sign


@functools.partial(jax.jit, static_argnames=())
def onnx_sign(*input_args):
  """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Sign for more details."""
  assert len(input_args) == 1
  data = input_args[0]
  return jnp.sign(data)
"""Define ONNX Tan operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op('Tan')
class Tan(handler.Handler):
  """Implementation of the ONNX Tan operator."""

  @classmethod
  def _prepare(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
  ):
    # Tan declares no attributes; the shared helper copies any node attrs
    # that match the impl's keyword parameters.
    onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl)

  @classmethod
  def version_7(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_7 Tan op."""
    cls._prepare(node, inputs, onnx_tan)
    return onnx_tan

  @classmethod
  def version_22(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_22 Tan op."""
    cls._prepare(node, inputs, onnx_tan)
    return onnx_tan


@functools.partial(jax.jit, static_argnames=())
def onnx_tan(*input_args):
  """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Tan for more details."""
  assert len(input_args) == 1
  data = input_args[0]
  return jnp.tan(data)
# Enable the ONNX backend node tests for the ops implemented above; each
# entry is a substring pattern matched against the generated test names.
include_patterns.append('test_squeeze_')
include_patterns.append('test_sub_')
include_patterns.append('test_sum_')
include_patterns.append('test_tan_')
include_patterns.append('test_tanh_')
include_patterns.append('test_tile_')
include_patterns.append('test_top_k_')