Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions jaxonnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from jaxonnxruntime.onnx_ops import globalaveragepool
from jaxonnxruntime.onnx_ops import greater
from jaxonnxruntime.onnx_ops import greaterorequal
from jaxonnxruntime.onnx_ops import hardsigmoid
from jaxonnxruntime.onnx_ops import identity
from jaxonnxruntime.onnx_ops import if_op
from jaxonnxruntime.onnx_ops import leakyrelu
Expand Down Expand Up @@ -90,6 +91,7 @@
from jaxonnxruntime.onnx_ops import reducesum
from jaxonnxruntime.onnx_ops import relu
from jaxonnxruntime.onnx_ops import reshape
from jaxonnxruntime.onnx_ops import resize
from jaxonnxruntime.onnx_ops import scatterelements
from jaxonnxruntime.onnx_ops import scatternd
from jaxonnxruntime.onnx_ops import selu
Expand Down
33 changes: 23 additions & 10 deletions jaxonnxruntime/core/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,28 @@ def handle(
Returns:
The jax function.
"""
ver_handle = getattr(cls, "version_{}".format(cls.SINCE_VERSION), None)
if ver_handle:
return ver_handle(node, inputs, **kwargs) # pylint: disable=not-callable

# Get all the methods that start with "version_"
class_methods = inspect.getmembers(cls, predicate=inspect.ismethod)
version_methods = [
method_name
for method_name, _ in class_methods
version_methods = {
int(method_name.split("_")[1]): method
for method_name, method in class_methods
if method_name.startswith("version_")
]
}

if not version_methods:
raise NotImplementedError(
f"{node.op_type} has no versioned implementations."
)

# Find the largest version which is <= cls.SINCE_VERSION
available_versions = sorted(version_methods.keys(), reverse=True)
for v in available_versions:
if v <= cls.SINCE_VERSION:
return version_methods[v](node, inputs, **kwargs)

raise NotImplementedError(
f"{node.op_type} version {cls.SINCE_VERSION} is not implemented."
f" Only have those versions: {version_methods}."
f" Only have those versions: {sorted(version_methods.keys())}."
)

@classmethod
Expand All @@ -104,7 +111,13 @@ def _prepare(
inputs: Sequence[Any],
onnx_jax_impl: Callable[..., Any],
) -> None:
"""Rwrite the OnnxNode to prepare the inputs attributes for the onnx jax implementation."""
"""Rwrite the OnnxNode to prepare the inputs attributes.

Args:
node: The onnx node to be handled.
inputs: The inputs of the onnx node.
onnx_jax_impl: The jax function for the onnx op.
"""
raise NotImplementedError


Expand Down
30 changes: 25 additions & 5 deletions jaxonnxruntime/onnx_ops/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,22 @@ class Clip(handler.Handler):
@classmethod
def _prepare_6(
    cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
) -> None:
  """Translate the opset-6 'min'/'max' attributes to jnp.clip's names.

  Absent attributes come back as None, leaving that bound open.
  """
  attrs = node.attrs
  node.attrs_dict['amin'] = attrs.get('min')
  node.attrs_dict['amax'] = attrs.get('max')

@classmethod
def _prepare_13(
    cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
) -> None:
  """Intentionally a no-op: min/max arrive as positional inputs here."""

@classmethod
def _prepare(
    cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
) -> None:
  """Default preparation path: defer to the opset-13 (no-op) rules."""
  return cls._prepare_13(node, inputs, onnx_jax_impl)

@classmethod
def version_6(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
Expand All @@ -50,6 +56,14 @@ def version_6(
cls._prepare_6(node, inputs, onnx_clip)
return onnx_clip

@classmethod
def version_11(
    cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
) -> Callable[..., Any]:
  """Return the jax impl for Clip at opset version 11."""
  impl = onnx_clip
  cls._prepare_13(node, inputs, impl)
  return impl

@classmethod
def version_13(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
Expand All @@ -60,8 +74,14 @@ def version_13(


@functools.partial(jax.jit, static_argnames=())
def onnx_clip(*input_args: jax.Array, amin=None, amax=None) -> jax.Array:
  """The impl for Clip operator.

  Args:
    *input_args: (data,), (data, min) or (data, min, max). When the
      optional positional min/max inputs are present they override the
      keyword bounds.
    amin: lower clip bound (from the opset-6 'min' attribute), or None for
      no lower bound.
    amax: upper clip bound (from the opset-6 'max' attribute), or None for
      no upper bound.

  Returns:
    data clamped element-wise to [amin, amax].
  """
  data = input_args[0]
  if len(input_args) >= 2:
    amin = input_args[1]
  if len(input_args) >= 3:
    amax = input_args[2]

  if amin is None and amax is None:
    return data
  # Pass the bounds positionally: recent jax renamed the a_min/a_max
  # keywords to min/max, and positional arguments work on both sides of
  # that change.
  return jnp.clip(data, amin, amax)
64 changes: 64 additions & 0 deletions jaxonnxruntime/onnx_ops/hardsigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2025 The Jaxonnxruntime Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define ONNX HardSigmoid operator."""
from collections.abc import Callable, Sequence
import functools
import inspect
from typing import Any

import jax
from jax import jit
from jax import numpy as jnp
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node


@handler.register_op("HardSigmoid")
class HardSigmoid(handler.Handler):
"""Implementation of the ONNX HardSigmoid operator."""

@classmethod
def _prepare(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
) -> None:
sig = inspect.signature(onnx_jax_impl)
kwparams = [
param.name
for param in sig.parameters.values()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
for name in kwparams:
node.attrs_dict[name] = node.attrs.get(name, None)

@classmethod
def version_6(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
) -> Callable[..., Any]:
"""ONNX version_6 HardSigmoid op."""
cls._prepare(node, inputs, onnx_hardsigmoid)
return onnx_hardsigmoid


@functools.partial(jit, static_argnames=("alpha", "beta"))
def onnx_hardsigmoid(
    *input_args: jax.Array, alpha: float = 0.2, beta: float = 0.5
) -> jax.Array:
  """The internal jax impl for onnx HardSigmoid op.

  Computes clip(alpha * x + beta, 0, 1). A None alpha/beta (stored when the
  node carries no such attribute) falls back to the ONNX defaults 0.2/0.5.
  """
  assert len(input_args) == 1
  (x,) = input_args
  slope = 0.2 if alpha is None else alpha
  shift = 0.5 if beta is None else beta
  return jnp.clip(slope * x + shift, 0.0, 1.0)
10 changes: 9 additions & 1 deletion jaxonnxruntime/onnx_ops/reducesum.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ def version_1(
cls._prepare(node, inputs, onnx_reducesum)
return onnx_reducesum

@classmethod
def version_11(
    cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
) -> Callable[..., Any]:
  """Return the jax impl for ReduceSum at opset version 11."""
  impl = onnx_reducesum
  cls._prepare(node, inputs, impl)
  return impl

@classmethod
def version_13(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
Expand All @@ -81,7 +89,7 @@ def onnx_reducesum(
keepdims=1,
noop_with_empty_axes=0,
):
"""The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#ReduceSum."""
"""The impl for ReduceSum operator."""
assert len(input_args) == 1 or len(input_args) == 2
data = input_args[0]
if axes is None and noop_with_empty_axes > 0:
Expand Down
122 changes: 122 additions & 0 deletions jaxonnxruntime/onnx_ops/resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2025 The Jaxonnxruntime Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Define ONNX Resize operator."""
from collections.abc import Callable, Sequence
import functools
from typing import Any
import jax
from jax import numpy as jnp
from jaxonnxruntime.core import config_class

config = config_class.config
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
import numpy as np


@handler.register_op('Resize')
class Resize(handler.Handler):
  """Implementation of the ONNX Resize operator."""

  @classmethod
  def _prepare(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
  ):
    """Fold the node's attributes and scales/sizes inputs into attrs_dict.

    Resolves a static output shape ('sizes') so the jax impl can be jitted
    with static shape arguments.

    Args:
      node: The onnx node to be handled.
      inputs: Runtime inputs (X, roi, scales, sizes); trailing entries may be
        None or absent.
      onnx_jax_impl: The jax function for the onnx op (unused here).

    Raises:
      ValueError: if neither a usable 'scales' nor 'sizes' can be resolved.
    """
    # Extract string attributes, decoding defensively in case the producer
    # stored them as bytes.
    mode = node.attrs.get('mode', 'nearest')
    if isinstance(mode, bytes):
      mode = mode.decode('utf-8')
    node.attrs_dict['mode'] = mode

    coordinate_transformation_mode = node.attrs.get('coordinate_transformation_mode', 'half_pixel')
    if isinstance(coordinate_transformation_mode, bytes):
      coordinate_transformation_mode = coordinate_transformation_mode.decode('utf-8')
    node.attrs_dict['coordinate_transformation_mode'] = coordinate_transformation_mode

    # Resize inputs are (X, roi, scales, sizes). Prefer an explicit 'sizes'
    # input; otherwise derive sizes from 'scales' and X's static shape.
    if len(inputs) >= 4 and inputs[3] is not None:
      node.attrs_dict['sizes'] = tuple(inputs[3].astype(int).tolist())
    elif len(inputs) >= 3 and inputs[2] is not None:
      scales = inputs[2]
      x_shape = inputs[0].shape
      # NOTE(review): int() truncates toward zero, which matches the ONNX
      # floor(shape * scale) rule only for non-negative values — confirm.
      sizes = [int(x_shape[i] * scales[i]) for i in range(len(x_shape))]
      node.attrs_dict['sizes'] = tuple(sizes)
    else:
      # Fallback: scales/sizes may be graph constants rather than runtime
      # values supplied in `inputs`.
      constant_dict = node.context_graph.get_constant_dict()
      if len(node.inputs) >= 4 and node.inputs[3] in constant_dict:
        node.attrs_dict['sizes'] = tuple(constant_dict[node.inputs[3]].astype(int).tolist())
      elif len(node.inputs) >= 3 and node.inputs[2] in constant_dict:
        scales = constant_dict[node.inputs[2]]
        x_shape = inputs[0].shape
        sizes = [int(x_shape[i] * scales[i]) for i in range(len(x_shape))]
        node.attrs_dict['sizes'] = tuple(sizes)
      else:
        # If still not found, this is a dynamic resize; jax needs static
        # shapes, so fail loudly rather than guessing.
        raise ValueError(f"Resize node {node.name} needs valid 'scales' or 'sizes'.")

  @classmethod
  def version_10(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_10 Resize op."""
    cls._prepare(node, inputs, onnx_resize)
    return onnx_resize

  @classmethod
  def version_11(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_11 Resize op."""
    cls._prepare(node, inputs, onnx_resize)
    return onnx_resize

  @classmethod
  def version_13(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_13 Resize op."""
    cls._prepare(node, inputs, onnx_resize)
    return onnx_resize

  @classmethod
  def version_18(
      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
  ) -> Callable[..., Any]:
    """ONNX version_18 Resize op."""
    cls._prepare(node, inputs, onnx_resize)
    return onnx_resize


@functools.partial(
    jax.jit, static_argnames=('sizes', 'mode', 'coordinate_transformation_mode')
)
def onnx_resize(*input_args, sizes, mode, coordinate_transformation_mode):
  """The impl for Resize.

  Args:
    *input_args: input_args[0] is the tensor X; any roi/scales/sizes inputs
      are ignored here because they were folded into the static kwargs.
    sizes: static output shape for the resized tensor.
    mode: ONNX interpolation mode ('nearest', 'linear', 'cubic').
    coordinate_transformation_mode: accepted for the op signature but not
      used by this implementation; jax.image.resize applies its own
      convention — NOTE(review): confirm this matches the models served.

  Returns:
    X resized to `sizes` via jax.image.resize.
  """
  x = input_args[0]
  # ONNX 'linear' on a rank-4 (N, C, H, W) tensor corresponds to jax's
  # 'bilinear'; every other mode name passes through unchanged.
  jax_method = 'bilinear' if mode == 'linear' and x.ndim == 4 else mode
  return jax.image.resize(x, shape=sizes, method=jax_method)
4 changes: 2 additions & 2 deletions jaxonnxruntime/onnx_ops/squeeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _prepare(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
):
if len(inputs) == 1:
node.attrs_dict['axis'] = None
node.attrs_dict['axis'] = node.attrs.get('axes', None)
else:
node.attrs_dict['axis'] = tuple(inputs[1].tolist())

Expand Down Expand Up @@ -76,6 +76,6 @@ def version_13(

@functools.partial(jax.jit, static_argnames='axis')
def onnx_squeeze(*input_args, axis):
  """The impl for Squeeze operator.

  Args:
    *input_args: input_args[0] is the tensor to squeeze; any further
      positional inputs were already consumed during preparation.
    axis: static tuple of axes to drop, or None to drop every unit axis.

  Returns:
    input_args[0] with the requested size-1 dimensions removed.
  """
  data = input_args[0]
  return jnp.squeeze(data, axis=axis)
11 changes: 11 additions & 0 deletions tests/onnx_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class NodeTest(absltest.TestCase):
include_patterns.append('test_globalaveragepool_')
include_patterns.append('test_greater_')
include_patterns.append('test_greaterorequal_')
include_patterns.append('test_hardsigmoid_')
include_patterns.append('test_identity_')
include_patterns.append('test_if_')
include_patterns.append('test_leakyrelu_')
Expand Down Expand Up @@ -133,6 +134,7 @@ class NodeTest(absltest.TestCase):
include_patterns.append('test_reduce_sum_')
include_patterns.append('test_relu_')
include_patterns.append('test_reshape_')
include_patterns.append('test_resize_')
include_patterns.append('test_scatter_elements_')
include_patterns.append('test_scatternd_')
include_patterns.append('test_selu_')
Expand Down Expand Up @@ -251,6 +253,15 @@ class NodeTest(absltest.TestCase):
'test_training_dropout_default_',
'test_training_dropout_default_mask_',
'test_training_dropout_mask_',
# Resize: Needs full spec implementation for coordinate transformations and rounding
'test_resize_downsample_',
'test_resize_tf_crop_',
'test_resize_upsample_scales_cubic_align_corners_',
'test_resize_upsample_scales_cubic_asymmetric_',
'test_resize_upsample_scales_cubic_gpu',
'test_resize_upsample_scales_linear_align_corners_',
'test_resize_upsample_scales_linear_half_pixel_symmetric_',
'test_resize_upsample_sizes_',
])


Expand Down