Hello,
Thanks for your great work. I have a question about how torch2jax-wrapped functions behave under jax.vmap. I see that the wrapper uses ffi.ffi_call("torch_call", outshapes, vmap_method="sequential"). Does this mean that, under vmap, the PyTorch function is called sequentially, once per batch element? If so, is there any way to improve this by using one of the other vmap_method options?