Compute hidden state updates in StatefulOnnxLabelScorer in batches#181
Compute hidden state updates in StatefulOnnxLabelScorer in batches#181
StatefulOnnxLabelScorer in batches#181Conversation
|
Is it maybe possible to offer both this batched version and a non-batched version if the state updater does not have a batch dimension? |
In principle this is possible but the only way to check for a batch dimension is looking up whether dim 0 is dynamic which might be unreliable and lead to edge cases. |
Ideally, I would like to have a check if a "real" batch dimension exists and if yes use the batched version and if not automatically fall back to the unbatched version (maybe with a warning). But I see that this is not straightforward to realize... |
The most expensive part in search with an AED or LSTM LM is the hidden-state update which is currently done one-by-one for each hypothesis individually. This PR changes the logic to collect the unique hidden states first and then forward the state updater in a batched manner. Depending on how the state updater exports are structured, this may break older ONNX models since now the state updater needs to accept and return batched tensors (with batch axis 0).
So far not tested yet.