Support Pytrees #11

@clemisch

Does kernex support pytrees as kernel inputs? I could not find an example. Pytree support would be very useful for moving-window filters with "global" weights or with multiple inputs, such as the cross-channel bilateral filter in my case.

Repro:

import jax.numpy as jnp
import kernex

@kernex.kmap(kernel_size=(3, 3))
def kernel(tree):
    x, y = tree
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(20 * 30).reshape((20, 30))
out = kernel((data, data))

raises

Traceback (most recent call last):
  File "/home/clemisch/kernex_tree.py", line 52, in <module>
    out = kernel((data, data))
          ^^^^^^^^^^^^^^^^^^^^
  File "/home/clemisch/venvs/11/lib64/python3.11/site-packages/kernex/interface/kernel_interface.py", line 131, in __call__
    self.shape = array.shape
                 ^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'shape'
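
To make the requested behavior concrete, here is a rough sketch in plain JAX of what I would expect a pytree-aware `kmap` to compute. It does not use kernex at all: `tree_kmap` is a hypothetical helper name, and the sketch assumes 2D leaves of identical shape, stride 1, and no padding ("valid" windows). It only illustrates the semantics and is not a proposal for how kernex should implement it.

```python
import jax
import jax.numpy as jnp


def tree_kmap(fn, kernel_size):
    """Hypothetical helper (not part of kernex): apply `fn` to every
    kernel_size window of a pytree of equally shaped 2D arrays."""
    kh, kw = kernel_size

    def apply(tree):
        leaves = jax.tree_util.tree_leaves(tree)
        h, w = leaves[0].shape  # assumption: all leaves share this shape
        out_h, out_w = h - kh + 1, w - kw + 1  # "valid" windows, stride 1

        def at(i, j):
            # Cut the same (kh, kw) window out of every leaf, keeping the
            # tree structure, and hand the windowed tree to the kernel.
            window = jax.tree_util.tree_map(
                lambda a: jax.lax.dynamic_slice(a, (i, j), (kh, kw)), tree
            )
            return fn(window)

        rows = jnp.arange(out_h)
        cols = jnp.arange(out_w)
        # Vectorize over all window origins (rows outer, columns inner).
        return jax.vmap(lambda i: jax.vmap(lambda j: at(i, j))(cols))(rows)

    return apply


def kernel(tree):
    x, y = tree
    return jnp.sum(x * jnp.square(y))


data = jnp.arange(20 * 30).reshape((20, 30))
out = tree_kmap(kernel, (3, 3))((data, data))
```

With the inputs from the repro, `out` has one value per 3x3 window position, and each leaf of the tuple is windowed in lockstep.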

Labels

enhancement: New feature or request
