Hello, @awf
Sorry if my question is naive, but I'm not sure where to find this in your implementation:
https://github.com/vpj/jax_transformer/blob/521d67e9160a6362a18e68e6b3aeafc270d40ad0/transformer.py#L588
Also, why don't you use a batched matmul instead of looping over the heads?
Thanks!
-
It's here: Or am I misunderstanding your question? And I don't use a batched matmul because the code is intended for educational use: to be clear and easy to read. In an ideal world, the JAX/XLA compiler would transform the graph to use batched operations, but I haven't checked whether it does.
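To make the tradeoff concrete, here is a minimal sketch contrasting the two formulations: a Python loop over heads versus the same single-head function mapped over the head axis with `jax.vmap`, which lets XLA emit batched matmuls. This is not the repository's actual code; the function names, shapes, and sizes are illustrative assumptions.

```python
# Sketch only: not code from vpj/jax_transformer; names and shapes are illustrative.
import jax
import jax.numpy as jnp


def single_head_attention(q, k, v):
    # q, k, v: [seq_len, d_head]
    scores = q @ k.T / jnp.sqrt(q.shape[-1])        # [seq_len, seq_len]
    return jax.nn.softmax(scores, axis=-1) @ v      # [seq_len, d_head]


def looped_heads(q, k, v):
    # q, k, v: [n_heads, seq_len, d_head]; explicit Python loop over heads,
    # in the spirit of the "clear and easy to read" style discussed above.
    outs = [single_head_attention(q[h], k[h], v[h]) for h in range(q.shape[0])]
    return jnp.concatenate(outs, axis=-1)           # [seq_len, n_heads * d_head]


def batched_heads(q, k, v):
    # Same math, but vmap maps over the head axis, so XLA sees batched matmuls.
    out = jax.vmap(single_head_attention)(q, k, v)  # [n_heads, seq_len, d_head]
    return jnp.swapaxes(out, 0, 1).reshape(out.shape[1], -1)


if __name__ == "__main__":
    n_heads, seq_len, d_head = 4, 8, 16
    q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, n_heads, seq_len, d_head))
    # Both formulations produce the same result.
    print(jnp.allclose(looped_heads(q, k, v), batched_heads(q, k, v), atol=1e-5))
```

As for whether the compiler batches the loop on its own, one way to check is to inspect the traced program with `jax.make_jaxpr(looped_heads)(q, k, v)` or, in recent JAX versions, the optimized HLO via `jax.jit(looped_heads).lower(q, k, v).compile().as_text()`.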
-
Thank you for your answer. For me, the line you sent me is here: https://github.com/vpj/jax_transformer/blob/521d67e9160a6362a18e68e6b3aeafc270d40ad0/transformer.py#L741