Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions egomimic/pl_utils/pl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,23 @@ def on_validation_start(self):
self.model.device = self.device

if self.trainer.is_global_zero:
os.makedirs(
os.path.join(self.video_dir(), f"epoch_{self.trainer.current_epoch}"),
exist_ok=True,
)
os.makedirs(os.path.join(self.video_dir(), f"epoch_{self.trainer.current_epoch}"),exist_ok=True)


@rank_zero_only
def validation_step(self, batch, batch_idx, dataloader_idx=0):
"""
Run a validation step on the batch, and save that batch of images into the val_image_buffer. Once the buffer hits 1000 images, save that as a 30fps video using torchvision.io.write_video.
"""
print(f"[VAL_STEP] rank={self.global_rank}, batch_idx={batch_idx}",flush=True)

batch = self.model.process_batch_for_training(batch)
metrics, images_dict = self.model.forward_eval_logging(batch)

metrics = {
k: (v.to(self.device) if torch.is_tensor(v) else torch.tensor(v, device=self.device))
for k, v in metrics.items()
}

## images is now a dict
for key, images in images_dict.items():
os.makedirs(
Expand All @@ -117,10 +121,10 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
self.val_image_buffer[key].clear()
self.val_counter[key] += 1

self.log_dict(metrics)
self.log_dict(metrics, sync_dist=True)

@rank_zero_only
def on_validation_end(self):
print(f"[ON_VALIDATION_END] rank={self.global_rank}",flush=True)
for key, buffer in self.val_image_buffer.items():
os.makedirs(
os.path.join(
Expand Down