Open
Labels: bug
Description
I think `mode="reflect"` in `padding_kwargs` is handled incorrectly:
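```python
import jax.numpy as jnp
import kernex

# Identity kernel: each output row is the raw size-3 window at that position
@kernex.kmap(
    kernel_size=(3,),
    padding=("same"),
    relative=False,
    padding_kwargs=dict(mode="reflect"),
)
def f(x):
    return x

x = jnp.array([1, 2, 3, 4, 5])
y = f(x)  # windows produced by kernex with reflect padding
z = jnp.pad(x, 1, mode="reflect")  # reference reflect padding from jnp.pad

print("x: ", x)
print("y: ", y)
print("z: ", z)
```

gives: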
```
x: [1 2 3 4 5]
y: [[3 1 2] # <-- the `3` is incorrect, should be `2`
 [1 2 3]
 [2 3 4]
 [3 4 5]
 [4 5 4]]
z: [2 1 2 3 4 5 4] # <-- here, the first element is `2`
```
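Kernex reflects the left edge incorrectly: its first window is `[3 1 2]` instead of `[2 1 2]`, while the remaining windows, including the right edge `[4 5 4]`, match the `jnp.pad` result.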
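For comparison, the windows one would expect can be built directly from `jnp.pad` plus index arithmetic. This is a minimal sketch, assuming an odd `kernel_size` so the padding is the same on both sides; `reflect_windows` is a hypothetical helper for checking results, not part of Kernex.

```python
import jax.numpy as jnp


def reflect_windows(x, kernel_size):
    """Sliding windows of a 1-D array after reflect padding ("same" length).

    Hypothetical reference helper, not a kernex API.
    Assumes an odd kernel_size so that pad = kernel_size // 2 on each side.
    """
    pad = kernel_size // 2
    # Mirror about the edge element without repeating it, e.g. [2 1 2 3 4 5 4]
    xp = jnp.pad(x, pad, mode="reflect")
    # Row i gathers xp[i : i + kernel_size]
    idx = jnp.arange(x.shape[0])[:, None] + jnp.arange(kernel_size)[None, :]
    return xp[idx]


x = jnp.array([1, 2, 3, 4, 5])
print(reflect_windows(x, 3))
# [[2 1 2]
#  [1 2 3]
#  [2 3 4]
#  [3 4 5]
#  [4 5 4]]
```

Comparing these rows against `y` shows that only the first window differs.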