-
Notifications
You must be signed in to change notification settings - Fork 2
Horus: added a pre-trained model - ResNet and checkpoint support. #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
jereml99
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Traning idzie całkiem powolutku mogę obczaić co tam tyle zajmuję
Widzę że są jakieś konflikty z mergem,
| self.feature_extractor = resnet50(weights=ResNet50_Weights.DEFAULT) | ||
| # Replace the last fully connected layer | ||
| # Parameters of newly constructed modules have requires_grad=True by default | ||
| num_ftrs = self.feature_extractor.fc.in_features |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A gdzie podajemy liczbę cech na wejście sieci?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Powinniśmy mieć automatycznego lintera xd
| def __getitem__(self, idx): | ||
| mel_spec, label = self.data[idx] | ||
| mel_spec = torch.from_numpy(mel_spec).unsqueeze(0) # Add channel dimension | ||
| mel_spec = mel_spec.expand(3, -1, -1) # Expand single channel to three channels |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
W sumie dlaczego?
| filename=f"{split_name}_bird_classifier_batch_{batch_size}_lr_{learning_rate}", | ||
| monitor="val_loss", # The metric to monitor | ||
| save_top_k=1, # Save only the top 1 models based on the metric monitored | ||
| mode="min", # In 'min' mode, training will stop when the quantity monitored has stopped decreasing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| mode="min", # In 'min' mode, training will stop when the quantity monitored has stopped decreasing | |
| mode="min", # In 'min' mode, the model with lowest val_loss is picked |
| # Save the trained model (The best model is saved by the checkpoint_callback) | ||
| # trainer.save_checkpoint(f"{split_name}_bird_classifier_batch_{batch_size}.ckpt") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do wyrzucenia może być
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fajne
|
|
||
| class BirdSpectrogramDataModule(pl.LightningDataModule): | ||
| def __init__(self, root_dir, batch_size=16): | ||
| def __init__(self, root_dir, batch_size=16, num_workers=4): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice
jereml99
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Traning idzie całkiem powolutku mogę obczaić co tam tyle zajmuję
Widzę że są jakieś konflikty z mergem,
No description provided.