Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions jaxonnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from jaxonnxruntime.onnx_ops import div
from jaxonnxruntime.onnx_ops import dropout
from jaxonnxruntime.onnx_ops import einsum
from jaxonnxruntime.onnx_ops import elu
from jaxonnxruntime.onnx_ops import equal
from jaxonnxruntime.onnx_ops import erf
from jaxonnxruntime.onnx_ops import exp
Expand All @@ -61,6 +62,7 @@
from jaxonnxruntime.onnx_ops import globalaveragepool
from jaxonnxruntime.onnx_ops import greater
from jaxonnxruntime.onnx_ops import greaterorequal
from jaxonnxruntime.onnx_ops import hardsigmoid
from jaxonnxruntime.onnx_ops import identity
from jaxonnxruntime.onnx_ops import if_op
from jaxonnxruntime.onnx_ops import leakyrelu
Expand Down Expand Up @@ -94,7 +96,9 @@
from jaxonnxruntime.onnx_ops import scatternd
from jaxonnxruntime.onnx_ops import selu
from jaxonnxruntime.onnx_ops import shape
from jaxonnxruntime.onnx_ops import shrink
from jaxonnxruntime.onnx_ops import sigmoid
from jaxonnxruntime.onnx_ops import sign
from jaxonnxruntime.onnx_ops import sin
from jaxonnxruntime.onnx_ops import sinh
from jaxonnxruntime.onnx_ops import slice
Expand All @@ -105,6 +109,7 @@
from jaxonnxruntime.onnx_ops import squeeze
from jaxonnxruntime.onnx_ops import sub
from jaxonnxruntime.onnx_ops import sum
from jaxonnxruntime.onnx_ops import tan
from jaxonnxruntime.onnx_ops import tanh
from jaxonnxruntime.onnx_ops import tile
from jaxonnxruntime.onnx_ops import topk
Expand Down
48 changes: 48 additions & 0 deletions jaxonnxruntime/onnx_ops/elu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Define ONNX Elu operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op("Elu")
class Elu(handler.Handler):
"""Implementation of the ONNX Elu operator."""

@classmethod
def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any):
node.attrs_dict['alpha'] = node.attrs.get(
'alpha', 1.0
)

@classmethod
def version_1(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_1 Elu op."""
cls._prepare(node, inputs, onnx_elu)
return onnx_elu

@classmethod
def version_6(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_6 Elu op."""
cls._prepare(node, inputs, onnx_elu)
return onnx_elu

@classmethod
def version_22(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_22 Elu op."""
cls._prepare(node, inputs, onnx_elu)
return onnx_elu


@functools.partial(jax.jit, static_argnames=())
def onnx_elu(*input_args, alpha):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Elu for more details."""
assert len(input_args) == 1
data = input_args[0]
return jax.nn.elu(data, alpha)
48 changes: 48 additions & 0 deletions jaxonnxruntime/onnx_ops/hardsigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Define ONNX HardSigmoid operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op("HardSigmoid")
class HardSigmoid(handler.Handler):
"""Implementation of the ONNX HardSigmoid operator."""

@classmethod
def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any):
node.attrs_dict["alpha"] = node.attrs.get("alpha", 0.2)
node.attrs_dict["beta"] = node.attrs.get("beta", 0.5)

@classmethod
def version_1(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_1 HardSigmoid op."""
cls._prepare(node, inputs, onnx_hardsigmoid)
return onnx_hardsigmoid

@classmethod
def version_6(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_6 HardSigmoid op."""
cls._prepare(node, inputs, onnx_hardsigmoid)
return onnx_hardsigmoid

@classmethod
def version_22(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_22 HardSigmoid op."""
cls._prepare(node, inputs, onnx_hardsigmoid)
return onnx_hardsigmoid


@functools.partial(jax.jit, static_argnames=())
def onnx_hardsigmoid(*input_args, alpha, beta):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#HardSigmoid for more details."""
assert len(input_args) == 1
data = input_args[0]
return jnp.maximum(0, jnp.minimum(1, data * alpha + beta)).astype(data.dtype)
40 changes: 40 additions & 0 deletions jaxonnxruntime/onnx_ops/shrink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Define ONNX Shrink operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op("Shrink")
class Shrink(handler.Handler):
"""Implementation of the ONNX Shrink operator."""

@classmethod
def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any):
node.attrs_dict['bias'] = node.attrs.get('bias', 0.0)
node.attrs_dict['lambd'] = node.attrs.get('lambd', 0.5)

@classmethod
def version_9(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_9 Shrink op."""
cls._prepare(node, inputs, onnx_shrink)
return onnx_shrink


@functools.partial(jax.jit, static_argnames=())
def onnx_shrink(*input_args, bias, lambd):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Shrink for more details."""
assert len(input_args) == 1
data = input_args[0]
return jnp.where(
data < -lambd,
data + bias,
jnp.where(data > lambd, data - bias, 0),
).astype(data.dtype)
41 changes: 41 additions & 0 deletions jaxonnxruntime/onnx_ops/sign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Define ONNX Sign operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op("Sign")
class Sign(handler.Handler):
"""Implementation of the ONNX Sign operator."""

@classmethod
def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any):
onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl)

@classmethod
def version_9(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_9 Sign op."""
cls._prepare(node, inputs, onnx_sign)
return onnx_sign

@classmethod
def version_13(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_13 Sign op."""
cls._prepare(node, inputs, onnx_sign)
return onnx_sign


@functools.partial(jax.jit, static_argnames=())
def onnx_sign(*input_args):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Sign for more details."""
assert len(input_args) == 1
data = input_args[0]
return jnp.sign(data)
41 changes: 41 additions & 0 deletions jaxonnxruntime/onnx_ops/tan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Define ONNX Tan operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
import functools
from collections.abc import Callable, Sequence
from typing import Any

import jax
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.onnx_ops import onnx_ops_utils


@handler.register_op("Tan")
class Tan(handler.Handler):
"""Implementation of the ONNX Tan operator."""

@classmethod
def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any):
onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl)

@classmethod
def version_7(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_7 Tan op."""
cls._prepare(node, inputs, onnx_tan)
return onnx_tan

@classmethod
def version_22(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
"""ONNX version_22 Tan op."""
cls._prepare(node, inputs, onnx_tan)
return onnx_tan


@functools.partial(jax.jit, static_argnames=())
def onnx_tan(*input_args):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Tan for more details."""
assert len(input_args) == 1
data = input_args[0]
return jnp.tan(data)
5 changes: 5 additions & 0 deletions tests/onnx_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class NodeTest(absltest.TestCase):
include_patterns.append('test_div_')
include_patterns.append('test_dropout_')
include_patterns.append('test_einsum_')
include_patterns.append('test_elu_')
include_patterns.append('test_equal_')
include_patterns.append('test_erf_')
include_patterns.append('test_exp_')
Expand All @@ -106,6 +107,7 @@ class NodeTest(absltest.TestCase):
include_patterns.append('test_globalaveragepool_')
include_patterns.append('test_greater_')
include_patterns.append('test_greaterorequal_')
include_patterns.append('test_hardsigmoid_')
include_patterns.append('test_identity_')
include_patterns.append('test_if_')
include_patterns.append('test_leakyrelu_')
Expand Down Expand Up @@ -137,7 +139,9 @@ class NodeTest(absltest.TestCase):
include_patterns.append('test_scatternd_')
include_patterns.append('test_selu_')
include_patterns.append('test_shape_')
include_patterns.append('test_shrink')
include_patterns.append('test_sigmoid_')
include_patterns.append('test_sign_')
include_patterns.append('test_sin_')
include_patterns.append('test_sinh_')
include_patterns.append('test_slice_')
Expand All @@ -148,6 +152,7 @@ class NodeTest(absltest.TestCase):
include_patterns.append('test_squeeze_')
include_patterns.append('test_sub_')
include_patterns.append('test_sum_')
include_patterns.append('test_tan_')
include_patterns.append('test_tanh_')
include_patterns.append('test_tile_')
include_patterns.append('test_top_k_')
Expand Down