Skip to content

Commit ec32f32

Browse files
committed
make initial guess also vectorized
1 parent 469174b commit ec32f32

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

driftbench/data_generation/solvers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self, f, w0, max_fit_attemps, vectorize=True):
5959
self.min_func = jit(
6060
vmap(
6161
partial(_minimize),
62-
in_axes=(None, None, None, None, 0, 0, 0, 0, 0, 0),
62+
in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0),
6363
),
6464
static_argnums=(0, 1, 2),
6565
)
@@ -98,7 +98,7 @@ def _solve_sequentially(self, X, callback=None):
9898
return jnp.array(coefficients)
9999

100100
def _solve_vectorized(self, X):
101-
solution = self.w0
101+
solution = jnp.tile(self.w0, (len(X), 1))
102102
l_x0_mat, l_x1_mat, l_x2_mat, l_y0_mat, l_y1_mat, l_y2_mat = (
103103
self._latents_to_array(X)
104104
)

0 commit comments

Comments
 (0)