diff --git a/linx/reactions.py b/linx/reactions.py index ce2102e..8cc585a 100644 --- a/linx/reactions.py +++ b/linx/reactions.py @@ -144,7 +144,7 @@ def __init__( gpus = jax.devices('gpu') self.T9_vec = jax.device_put(self.T9_vec, device=gpus[0]) self.mu_median_vec = jax.device_put( - self.mu_media_vec, device=gpus[0] + self.mu_median_vec, device=gpus[0] ) self.expsigma_vec = jax.device_put( self.expsigma_vec, device=gpus[0]