diff --git a/jaxonnxruntime/__init__.py b/jaxonnxruntime/__init__.py index fe5349f..438a393 100644 --- a/jaxonnxruntime/__init__.py +++ b/jaxonnxruntime/__init__.py @@ -61,6 +61,7 @@ from jaxonnxruntime.onnx_ops import globalaveragepool from jaxonnxruntime.onnx_ops import greater from jaxonnxruntime.onnx_ops import greaterorequal +from jaxonnxruntime.onnx_ops import hardsigmoid from jaxonnxruntime.onnx_ops import identity from jaxonnxruntime.onnx_ops import if_op from jaxonnxruntime.onnx_ops import leakyrelu @@ -90,6 +91,7 @@ from jaxonnxruntime.onnx_ops import reducesum from jaxonnxruntime.onnx_ops import relu from jaxonnxruntime.onnx_ops import reshape +from jaxonnxruntime.onnx_ops import resize from jaxonnxruntime.onnx_ops import scatterelements from jaxonnxruntime.onnx_ops import scatternd from jaxonnxruntime.onnx_ops import selu diff --git a/jaxonnxruntime/core/handler.py b/jaxonnxruntime/core/handler.py index 161e4d0..4f7374b 100644 --- a/jaxonnxruntime/core/handler.py +++ b/jaxonnxruntime/core/handler.py @@ -80,21 +80,28 @@ def handle( Returns: The jax function. """ - ver_handle = getattr(cls, "version_{}".format(cls.SINCE_VERSION), None) - if ver_handle: - return ver_handle(node, inputs, **kwargs) # pylint: disable=not-callable - # Get all the methods that start with "version_" class_methods = inspect.getmembers(cls, predicate=inspect.ismethod) - version_methods = [ - method_name - for method_name, _ in class_methods + version_methods = { + int(method_name.split("_")[1]): method + for method_name, method in class_methods if method_name.startswith("version_") - ] + } + + if not version_methods: + raise NotImplementedError( + f"{node.op_type} has no versioned implementations." + ) + + # Find the largest version which is <= cls.SINCE_VERSION + available_versions = sorted(version_methods.keys(), reverse=True) + for v in available_versions: + if v <= cls.SINCE_VERSION: + return version_methods[v](node, inputs, **kwargs) raise NotImplementedError( f"{node.op_type} version {cls.SINCE_VERSION} is not implemented." - f" Only have those versions: {version_methods}." + f" Only have those versions: {sorted(version_methods.keys())}." ) @classmethod @@ -104,7 +111,13 @@ def _prepare( inputs: Sequence[Any], onnx_jax_impl: Callable[..., Any], ) -> None: - """Rwrite the OnnxNode to prepare the inputs attributes for the onnx jax implementation.""" + """Rwrite the OnnxNode to prepare the inputs attributes. + + Args: + node: The onnx node to be handled. + inputs: The inputs of the onnx node. + onnx_jax_impl: The jax function for the onnx op. + """ raise NotImplementedError diff --git a/jaxonnxruntime/onnx_ops/clip.py b/jaxonnxruntime/onnx_ops/clip.py index 2aed467..1c98070 100644 --- a/jaxonnxruntime/onnx_ops/clip.py +++ b/jaxonnxruntime/onnx_ops/clip.py @@ -32,16 +32,22 @@ class Clip(handler.Handler): @classmethod def _prepare_6( cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any - ): + ) -> None: node.attrs_dict['amin'] = node.attrs.get('min') node.attrs_dict['amax'] = node.attrs.get('max') @classmethod def _prepare_13( cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any - ): + ) -> None: pass + @classmethod + def _prepare( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any + ) -> None: + cls._prepare_13(node, inputs, onnx_jax_impl) + @classmethod def version_6( cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] @@ -50,6 +56,14 @@ def version_6( cls._prepare_6(node, inputs, onnx_clip) return onnx_clip + @classmethod + def version_11( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] + ) -> Callable[..., Any]: + """ONNX version_11 Clip op.""" + cls._prepare_13(node, inputs, onnx_clip) + return onnx_clip + @classmethod def version_13( cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] @@ -60,8 +74,14 @@ def version_13( @functools.partial(jax.jit, static_argnames=()) -def onnx_clip(data, amin=None, amax=None): - """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Clip for more details.""" +def onnx_clip(*input_args: jax.Array, amin=None, amax=None) -> jax.Array: + """The impl for Clip operator.""" + data = input_args[0] + if len(input_args) >= 2: + amin = input_args[1] + if len(input_args) >= 3: + amax = input_args[2] + if amin is None and amax is None: return data - return jnp.clip(data, min=amin, max=amax) + return jnp.clip(data, a_min=amin, a_max=amax) diff --git a/jaxonnxruntime/onnx_ops/hardsigmoid.py b/jaxonnxruntime/onnx_ops/hardsigmoid.py new file mode 100644 index 0000000..092c360 --- /dev/null +++ b/jaxonnxruntime/onnx_ops/hardsigmoid.py @@ -0,0 +1,64 @@ +# Copyright 2025 The Jaxonnxruntime Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Define ONNX HardSigmoid operator.""" +from collections.abc import Callable, Sequence +import functools +import inspect +from typing import Any + +import jax +from jax import jit +from jax import numpy as jnp +from jaxonnxruntime.core import handler +from jaxonnxruntime.core import onnx_node + + +@handler.register_op("HardSigmoid") +class HardSigmoid(handler.Handler): + """Implementation of the ONNX HardSigmoid operator.""" + + @classmethod + def _prepare( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any + ) -> None: + sig = inspect.signature(onnx_jax_impl) + kwparams = [ + param.name + for param in sig.parameters.values() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] + for name in kwparams: + node.attrs_dict[name] = node.attrs.get(name, None) + + @classmethod + def version_6( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] + ) -> Callable[..., Any]: + """ONNX version_6 HardSigmoid op.""" + cls._prepare(node, inputs, onnx_hardsigmoid) + return onnx_hardsigmoid + + +@functools.partial(jit, static_argnames=("alpha", "beta")) +def onnx_hardsigmoid( + *input_args: jax.Array, alpha: float = 0.2, beta: float = 0.5 +) -> jax.Array: + """The internal jax impl for onnx HardSigmoid op.""" + assert len(input_args) == 1 + x = input_args[0] + if alpha is None: + alpha = 0.2 + if beta is None: + beta = 0.5 + return jnp.clip(alpha * x + beta, 0.0, 1.0) diff --git a/jaxonnxruntime/onnx_ops/reducesum.py b/jaxonnxruntime/onnx_ops/reducesum.py index 8338980..a79f83e 100644 --- a/jaxonnxruntime/onnx_ops/reducesum.py +++ b/jaxonnxruntime/onnx_ops/reducesum.py @@ -63,6 +63,14 @@ def version_1( cls._prepare(node, inputs, onnx_reducesum) return onnx_reducesum + @classmethod + def version_11( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] + ) -> Callable[..., Any]: + """ONNX version_11 ReduceSum op.""" + cls._prepare(node, inputs, onnx_reducesum) + return onnx_reducesum + @classmethod def version_13( cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] @@ -81,7 +89,7 @@ def onnx_reducesum( keepdims=1, noop_with_empty_axes=0, ): - """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#ReduceSum.""" + """The impl for ReduceSum operator.""" assert len(input_args) == 1 or len(input_args) == 2 data = input_args[0] if axes is None and noop_with_empty_axes > 0: diff --git a/jaxonnxruntime/onnx_ops/resize.py b/jaxonnxruntime/onnx_ops/resize.py new file mode 100644 index 0000000..ddf698f --- /dev/null +++ b/jaxonnxruntime/onnx_ops/resize.py @@ -0,0 +1,122 @@ +# Copyright 2025 The Jaxonnxruntime Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Define ONNX Resize operator.""" +from collections.abc import Callable, Sequence +import functools +from typing import Any +import jax +from jax import numpy as jnp +from jaxonnxruntime.core import config_class + +config = config_class.config +from jaxonnxruntime.core import handler +from jaxonnxruntime.core import onnx_node +import numpy as np + + +@handler.register_op('Resize') +class Resize(handler.Handler): + """Implementation of the ONNX Resize operator.""" + + @classmethod + def _prepare( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any + ): + # Extract attributes + # attrs are already strings in this jaxonnxruntime version + mode = node.attrs.get('mode', 'nearest') + if isinstance(mode, bytes): + mode = mode.decode('utf-8') + node.attrs_dict['mode'] = mode + + coordinate_transformation_mode = node.attrs.get('coordinate_transformation_mode', 'half_pixel') + if isinstance(coordinate_transformation_mode, bytes): + coordinate_transformation_mode = coordinate_transformation_mode.decode('utf-8') + node.attrs_dict['coordinate_transformation_mode'] = coordinate_transformation_mode + + # Resize has inputs: X, roi, scales, sizes + # We need to determine target 'sizes' + if len(inputs) >= 4 and inputs[3] is not None: + node.attrs_dict['sizes'] = tuple(inputs[3].astype(int).tolist()) + elif len(inputs) >= 3 and inputs[2] is not None: + scales = inputs[2] + x_shape = inputs[0].shape + sizes = [int(x_shape[i] * scales[i]) for i in range(len(x_shape))] + node.attrs_dict['sizes'] = tuple(sizes) + else: + # Fallback: check if sizes/scales are in constant dict + constant_dict = node.context_graph.get_constant_dict() + if len(node.inputs) >= 4 and node.inputs[3] in constant_dict: + node.attrs_dict['sizes'] = tuple(constant_dict[node.inputs[3]].astype(int).tolist()) + elif len(node.inputs) >= 3 and node.inputs[2] in constant_dict: + scales = constant_dict[node.inputs[2]] + x_shape = inputs[0].shape + sizes = [int(x_shape[i] * scales[i]) for i in range(len(x_shape))] + node.attrs_dict['sizes'] = tuple(sizes) + else: + # If still not found, we might have to wait for runtime shape or it's a dynamic resize. + # However, for JAX, we prefer static sizes. + raise ValueError(f"Resize node {node.name} needs valid 'scales' or 'sizes'.") + + @classmethod + def version_10( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] + ) -> Callable[..., Any]: + """ONNX version_10 Resize op.""" + cls._prepare(node, inputs, onnx_resize) + return onnx_resize + + @classmethod + def version_11( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] + ) -> Callable[..., Any]: + """ONNX version_11 Resize op.""" + cls._prepare(node, inputs, onnx_resize) + return onnx_resize + + @classmethod + def version_13( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] + ) -> Callable[..., Any]: + """ONNX version_13 Resize op.""" + cls._prepare(node, inputs, onnx_resize) + return onnx_resize + + @classmethod + def version_18( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] + ) -> Callable[..., Any]: + """ONNX version_18 Resize op.""" + cls._prepare(node, inputs, onnx_resize) + return onnx_resize + + +@functools.partial(jax.jit, static_argnames=('sizes', 'mode', 'coordinate_transformation_mode')) +def onnx_resize(*input_args, sizes, mode, coordinate_transformation_mode): + """The impl for Resize.""" + x = input_args[0] + + # Map ONNX modes to JAX modes + # ONNX modes: 'nearest', 'linear', 'cubic' + # JAX modes: 'nearest', 'linear', 'bilinear', 'trilinear', 'cubic', 'lanczos3', 'lanczos5' + jax_method = mode + if jax_method == 'linear': + # If rank is 4 (N, C, H, W), 'bilinear' is more appropriate for 2D spatial resize + if x.ndim == 4: + jax_method = 'bilinear' + else: + jax_method = 'linear' + + return jax.image.resize(x, shape=sizes, method=jax_method) diff --git a/jaxonnxruntime/onnx_ops/squeeze.py b/jaxonnxruntime/onnx_ops/squeeze.py index a08f9e0..a817435 100644 --- a/jaxonnxruntime/onnx_ops/squeeze.py +++ b/jaxonnxruntime/onnx_ops/squeeze.py @@ -45,7 +45,7 @@ def _prepare( cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any ): if len(inputs) == 1: - node.attrs_dict['axis'] = None + node.attrs_dict['axis'] = node.attrs.get('axes', None) else: node.attrs_dict['axis'] = tuple(inputs[1].tolist()) @@ -76,6 +76,6 @@ def version_13( @functools.partial(jax.jit, static_argnames='axis') def onnx_squeeze(*input_args, axis): - """The impl for https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Squeeze.""" + """The impl for Squeeze operator.""" x = input_args[0] return jnp.squeeze(x, axis=axis) diff --git a/tests/onnx_ops_test.py b/tests/onnx_ops_test.py index 33e21b8..0edf2c7 100644 --- a/tests/onnx_ops_test.py +++ b/tests/onnx_ops_test.py @@ -106,6 +106,7 @@ class NodeTest(absltest.TestCase): include_patterns.append('test_globalaveragepool_') include_patterns.append('test_greater_') include_patterns.append('test_greaterorequal_') +include_patterns.append('test_hardsigmoid_') include_patterns.append('test_identity_') include_patterns.append('test_if_') include_patterns.append('test_leakyrelu_') @@ -133,6 +134,7 @@ class NodeTest(absltest.TestCase): include_patterns.append('test_reduce_sum_') include_patterns.append('test_relu_') include_patterns.append('test_reshape_') +include_patterns.append('test_resize_') include_patterns.append('test_scatter_elements_') include_patterns.append('test_scatternd_') include_patterns.append('test_selu_') @@ -251,6 +253,15 @@ class NodeTest(absltest.TestCase): 'test_training_dropout_default_', 'test_training_dropout_default_mask_', 'test_training_dropout_mask_', + # Resize: Needs full spec implementation for coordinate transformations and rounding + 'test_resize_downsample_', + 'test_resize_tf_crop_', + 'test_resize_upsample_scales_cubic_align_corners_', + 'test_resize_upsample_scales_cubic_asymmetric_', + 'test_resize_upsample_scales_cubic_gpu', + 'test_resize_upsample_scales_linear_align_corners_', + 'test_resize_upsample_scales_linear_half_pixel_symmetric_', + 'test_resize_upsample_sizes_', ])