Refactor trainer class #101
Conversation
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 10 comments.
Comments suppressed due to low confidence (3)
QuantumGravPy/src/QuantumGrav/train.py:710
- The docstring parameter type doesn't match the actual type annotation. The parameter is typed as pd.DataFrame in the signature, but the docstring says list[Any]. Update the docstring to reflect the correct type.
"""Check the status of the model during training.
Args:
eval_data (list[Any]): The evaluation data from the training epoch.
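A minimal sketch of the corrected docstring, assuming the signature does annotate eval_data as pd.DataFrame; the class and method names here are hypothetical, since the review only quotes the docstring:

```python
import pandas as pd

class Trainer:
    # Method name is hypothetical; only the docstring appears in the review.
    def check_status(self, eval_data: pd.DataFrame) -> None:
        """Check the status of the model during training.

        Args:
            eval_data (pd.DataFrame): The evaluation data from the training epoch.
        """
```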
QuantumGravPy/src/QuantumGrav/train.py:320
- The docstring is outdated. The parameters criterion, apply_model, early_stopping, validator, and tester are no longer function parameters; they are now extracted from the config dictionary. The docstring should only document the config parameter and describe what keys it should contain.
"""Initialize the trainer.
Args:
config (dict[str, Any]): The configuration dictionary.
criterion (Callable): The loss function to use.
apply_model (Callable | None, optional): A function to apply the model. Defaults to None.
early_stopping (Callable[[Collection[Any]], bool] | None, optional): A function for early stopping. Defaults to None.
validator (DefaultValidator | None, optional): A validator for model evaluation. Defaults to None.
tester (DefaultTester | None, optional): A tester for model evaluation. Defaults to None.
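A sketch of what the trimmed-down docstring could look like; the key names mirror the former parameters listed in the comment, but their exact spelling in the config is an assumption:

```python
from typing import Any

class Trainer:  # fragment for illustration, not the PR's full class
    def __init__(self, config: dict[str, Any]):
        """Initialize the trainer.

        Args:
            config (dict[str, Any]): The configuration dictionary. Assumed
                keys: "criterion" (the loss function), "apply_model",
                "early_stopping", "validator", and "tester".
        """
        self.config = config
```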
QuantumGravPy/src/QuantumGrav/train_ddp.py:110
- The docstring is outdated. The parameters criterion, apply_model, early_stopping, validator, and tester are no longer function parameters; they are now extracted from the config dictionary. The docstring should only document the rank and config parameters.
"""Initialize the distributed data parallel (DDP) trainer.
Args:
rank (int): The rank of the current process.
config (dict[str, Any]): The configuration dictionary.
criterion (Callable): The loss function.
apply_model (Callable | None, optional): The function to apply the model. Defaults to None.
early_stopping (Callable[[list[dict[str, Any]]], bool] | None, optional): The early stopping function. Defaults to None.
validator (DefaultValidator | None, optional): The validator for model evaluation. Defaults to None.
tester (DefaultTester | None, optional): The tester for model testing. Defaults to None.
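The DDP variant calls for the same fix; a hedged sketch of how the former parameters might now be pulled out of config during initialization (the key names are assumptions, since the review only states that these values come from the config dictionary):

```python
from typing import Any

class TrainerDDP:  # fragment for illustration only
    def __init__(self, rank: int, config: dict[str, Any]):
        """Initialize the distributed data parallel (DDP) trainer.

        Args:
            rank (int): The rank of the current process.
            config (dict[str, Any]): The configuration dictionary.
        """
        self.rank = rank
        # Assumed key names; optional entries fall back to None.
        self.criterion = config["criterion"]
        self.apply_model = config.get("apply_model")
        self.early_stopping = config.get("early_stopping")
        self.validator = config.get("validator")
        self.tester = config.get("tester")
```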
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 11 comments.
Comments suppressed due to low confidence (1)
QuantumGravPy/src/QuantumGrav/train.py:877
- Missing validation: The docstring at line 863 states "Raises: ValueError: If the model is not initialized" but there's no check for self.model is None before calling self.model.save(outpath) at line 877. This will result in an AttributeError instead of the documented ValueError. Since the model is now initialized in __init__, this might be intentional, but the docstring should be updated to reflect the actual behavior or the check should be added back.
def save_checkpoint(self, name_addition: str = ""):
    """Save model checkpoint.

    Raises:
        ValueError: If the model is not initialized.
        ValueError: If the model configuration does not contain 'name'.
        ValueError: If the training configuration does not contain 'checkpoint_path'.
    """
    self.logger.info(
        f"Saving checkpoint for model at epoch {self.epoch} to {self.checkpoint_path}"
    )
    outpath = self.checkpoint_path / f"model_{name_addition}.pt"
    if outpath.exists() is False:
        outpath.parent.mkdir(parents=True, exist_ok=True)
        self.logger.debug(f"Created directory {outpath.parent} for checkpoint.")
    self.latest_checkpoint = outpath
    self.model.save(outpath)
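One way to make the behavior match the docstring again, sketched against the snippet above (the error message is illustrative; the redundant exists() check is also dropped, since mkdir with exist_ok=True is already safe to call):

```python
def save_checkpoint(self, name_addition: str = ""):
    """Save model checkpoint.

    Raises:
        ValueError: If the model is not initialized.
    """
    if self.model is None:
        raise ValueError("Model is not initialized; cannot save a checkpoint.")
    outpath = self.checkpoint_path / f"model_{name_addition}.pt"
    outpath.parent.mkdir(parents=True, exist_ok=True)  # no-op if it already exists
    self.latest_checkpoint = outpath
    self.model.save(outpath)
```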
mephistoteles-whatever left a comment:
Looks good. I've got some small comments. The main question I had was about the determinism of numpy (non-)seeding in the trainer class.
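For reference, deterministic seeding in a trainer usually means seeding every RNG the training loop touches; a minimal sketch (illustrative only, not claimed to match the PR):

```python
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    """Seed the RNGs a training loop typically touches."""
    random.seed(seed)
    np.random.seed(seed)     # the global NumPy RNG varies per run if left unseeded
    torch.manual_seed(seed)  # seeds the CPU and CUDA generators
```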
Added remarks, and added the possibility to use learning rate schedulers. That's useful to stop optimizers from jumping around a minimum in later stages of the optimization, while still allowing them to escape bad minima and take larger steps at the start.
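To illustrate that trade-off, a minimal sketch with a stock PyTorch scheduler; the model and loss here are placeholders, not the PR's actual training loop:

```python
import torch

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Cosine annealing keeps steps large at the start (helping escape bad minima)
# and shrinks them later so the optimizer stops jumping around a minimum.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy loss for illustration
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the schedule once per epoch
```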
Trainer and TrainerDDP