Train batch generic #724
felipemello1
left a comment
I don't think this class should be in trainer.py. It probably belongs in types.py or something like that. Are you also going to add it to collate and test it in this PR?
Why wouldn't this be in the trainer.py file under api? It defines the training API of which this is part. I would vote to keep it in the trainer API.
81e475d to 34af55b
This is also used in collate_fn, and I'm not sure whether it may be used in other places. I think we would be exposed to circular dependencies, e.g. collate imports from train. Also, that's what other frameworks do, like tinker: https://github.com/thinking-machines-lab/tinker/blob/ad03d44978096b1dcae662e469293e70f509d5a8/src/tinker/types/datum.py#L25
What would X be here? I will not hold up the PR on this point but am curious b/c I have a hard time imagining what that would be.
I will leave that as an exercise for the reader. jk, I guess it cannot happen if collate is its own file and doesn't really import from anywhere. It just makes more sense to me, given the patterns I have seen. But no big deal either way. Worst case, we refactor later.
Summary
Adds a `TrainBatch` dataclass that separates `model_inputs` from `loss_inputs`, enabling any training paradigm without type changes.

Motivation
The current `TextTrainBatch` has limitations:

Solution
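A minimal sketch of the shape the Summary describes, assuming both groups are plain dictionaries of named values. Only the names `TrainBatch`, `model_inputs`, and `loss_inputs` come from this PR; the field types, the `collate` and `train_step` helpers, and the `tokens`/`labels` keys are hypothetical and may differ from the actual code.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class TrainBatch:
    """Hypothetical sketch: forward-pass inputs kept apart from loss inputs."""

    # Everything passed to the model's forward call.
    model_inputs: dict[str, Any]
    # Everything the loss needs besides the model outputs.
    loss_inputs: dict[str, Any] = field(default_factory=dict)


def collate(samples: list[dict[str, list[int]]]) -> TrainBatch:
    """Hypothetical collate helper: groups per-sample fields into one batch."""
    return TrainBatch(
        model_inputs={"tokens": [s["tokens"] for s in samples]},
        loss_inputs={"labels": [s["labels"] for s in samples]},
    )


def train_step(batch: TrainBatch, model: Callable, loss_fn: Callable) -> Any:
    """Run the model on model_inputs, then score it with loss_inputs."""
    outputs = model(**batch.model_inputs)
    return loss_fn(outputs, **batch.loss_inputs)


# Toy stand-ins for a model and a loss, for illustration only.
def toy_model(tokens: list[list[int]]) -> list[list[int]]:
    return [[t + 1 for t in row] for row in tokens]


def toy_loss(outputs: list[list[int]], labels: list[list[int]]) -> int:
    return sum(
        abs(o - l)
        for out_row, label_row in zip(outputs, labels)
        for o, l in zip(out_row, label_row)
    )


batch = collate([
    {"tokens": [1, 2], "labels": [2, 3]},
    {"tokens": [4, 5], "labels": [5, 6]},
])
loss = train_step(batch, toy_model, toy_loss)
```

In this sketch, a different training paradigm would only change the keys placed in `loss_inputs` (and the loss function that consumes them), while `train_step` and the `TrainBatch` type stay the same.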
Test Plan