diff --git a/fruit_nerf/fruit_nerf.py b/fruit_nerf/fruit_nerf.py index 004db82..3322071 100644 --- a/fruit_nerf/fruit_nerf.py +++ b/fruit_nerf/fruit_nerf.py @@ -404,7 +404,7 @@ def get_image_metrics_and_images( self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor] ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]: image = batch["image"].to(self.device) - rgb = outputs["rgb"] + rgb = outputs["rgb"].to(self.device) rgb = torch.clamp(rgb, min=0, max=1) acc = colormaps.apply_colormap(outputs["accumulation"]) depth = colormaps.apply_depth_colormap( @@ -452,7 +452,7 @@ def get_image_metrics_and_images( from torchmetrics.classification import BinaryJaccardIndex metric = BinaryJaccardIndex().to(self.device) semantic_labels = torch.nn.functional.softmax(outputs["semantics"]) - iou = metric(semantic_labels[..., 0], batch["fruit_mask"][..., 0]) + iou = metric(semantic_labels[..., 0].to(self.device), batch["fruit_mask"][..., 0]) metrics_dict["iou"] = float(iou) return metrics_dict, images_dict