diff --git a/.gitignore b/.gitignore index 553218f..989fc3e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__/ *.so /segmentation/* /registration/* +/datasets/* debug/nerf_output.txt @@ -32,6 +33,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +.vscode/ # PyInstaller # Usually these files are written by a python script from a template diff --git a/README.md b/README.md index 430c4a9..e549737 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,7 @@ Download pretrained weights cd .. # Download into grounded_sam wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth +wget https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth ``` Install SAM-HQ diff --git a/fruit_nerf/data/fruit_datamanager.py b/fruit_nerf/data/fruit_datamanager.py index e9c0faf..0394595 100644 --- a/fruit_nerf/data/fruit_datamanager.py +++ b/fruit_nerf/data/fruit_datamanager.py @@ -1,32 +1,21 @@ """ -Fruit tamanager. +FruitDataManager implementation. """ from __future__ import annotations from dataclasses import dataclass, field -from typing import ( - Dict, - Literal, - Optional, - Tuple, - Type, - Union, -) +from typing import Dict, Literal, Optional, Tuple, Type, Union import torch +from torch.nn import Parameter from typing_extensions import TypeVar from nerfstudio.cameras.rays import RayBundle - from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs -from nerfstudio.data.pixel_samplers import ( - PixelSampler, -) - +from nerfstudio.data.pixel_samplers import PixelSampler from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig - from fruit_nerf.components.ray_generators import OrthographicRayGenerator from fruit_nerf.data.fruit_dataset import FruitDataset @@ -80,21 +69,14 @@ def sample_surface_points(aabb, n, device, noise=False): Returns: torch tensor: Tensor of shape (num_points, 3) containing the sampled 3D coordinates. """ - # select three corners (must be adjacent!) corner_1 = aabb[0] # x corner_2 = aabb[1] # y corner_3 = aabb[2] # z - # Check if elements are to far away (check if adjacent) - # assert torch.abs(torch.sum(corner_1 - corner_2)) == 2.0 - # assert torch.abs(torch.sum(corner_1 - corner_3)) == 2.0 - dx_y_z = torch.abs(torch.max(aabb, axis=0).values - torch.min(aabb, axis=0).values) - # Part where the coordinate does not change constant_axis_part_pos = int(torch.argmax(torch.logical_and((corner_1 == corner_2), (corner_2 == corner_3)).to(int))) - # Generate meshgrid along XY plane start_x_pos = torch.argmax(torch.abs(corner_1 - corner_2)) x = torch.linspace(corner_1[start_x_pos], corner_2[start_x_pos], int(dx_y_z[0] / dx_y_z[constant_axis_part_pos] * n), dtype=torch.float32, device=device) @@ -104,13 +86,11 @@ def sample_surface_points(aabb, n, device, noise=False): xx, yy = torch.meshgrid(x, y) - # Flatten the meshgrid and set Z coordinate to the minimum Z value of the AABB surface_points = torch.column_stack( (xx.flatten(), yy.flatten(), torch.full_like(xx.flatten(), corner_3[constant_axis_part_pos]))) - # Convert to torch tensor surface_points_tensor = surface_points.clone() corner_4 = aabb[-1] @@ -122,17 +102,7 @@ def sample_surface_points(aabb, n, device, noise=False): class FruitDataManager(VanillaDataManager): - """Basic stored data manager implementation. - - This is pretty much a port over from our old dataloading utilities, and is a little jank - under the hood. 
We may clean this up a little bit under the hood with more standard dataloading - components that can be strung together, but it can be just used as a black box for now since - only the constructor is likely to change in the future, or maybe passing in step number to the - next_train and next_eval functions. - - Args: - config: the DataManagerConfig used to instantiate class - """ + """FruitDataManager implementation.""" config: FruitDataManagerConfig train_dataset: TDataset diff --git a/fruit_nerf/fruit_field.py b/fruit_nerf/fruit_field.py index d40ee40..b982a6c 100644 --- a/fruit_nerf/fruit_field.py +++ b/fruit_nerf/fruit_field.py @@ -35,7 +35,7 @@ ) from nerfstudio.field_components.mlp import MLP from nerfstudio.field_components.spatial_distortions import SpatialDistortion -from nerfstudio.fields.base_field import Field, shift_directions_for_tcnn +from nerfstudio.fields.base_field import Field, get_normalized_directions from fruit_nerf.components.field_heads import SemanticFieldHead @@ -206,7 +206,7 @@ def get_inference_outputs( outputs[FieldHeadNames.SEMANTICS] = self.field_head_semantics(x) if render_rgb: - directions = shift_directions_for_tcnn(ray_samples.frustums.directions) + directions = get_normalized_directions(ray_samples.frustums.directions) directions_flat = directions.view(-1, 3) d = self.direction_encoding(directions_flat) outputs_shape = ray_samples.frustums.directions.shape[:-1] @@ -241,7 +241,7 @@ def get_outputs( if ray_samples.camera_indices is None: raise AttributeError("Camera indices are not provided.") camera_indices = ray_samples.camera_indices.squeeze() - directions = shift_directions_for_tcnn(ray_samples.frustums.directions) + directions = get_normalized_directions(ray_samples.frustums.directions) directions_flat = directions.view(-1, 3) d = self.direction_encoding(directions_flat) diff --git a/fruit_nerf/fruit_nerf.py b/fruit_nerf/fruit_nerf.py index 225201c..bb0d6e8 100644 --- a/fruit_nerf/fruit_nerf.py +++ b/fruit_nerf/fruit_nerf.py @@ -1,5 +1,5 @@ """ -FruitNeRF implementation . +FruitNeRF implementation. 
""" from __future__ import annotations @@ -10,10 +10,9 @@ import numpy as np import torch from torch.nn import Parameter -from torchmetrics import PeakSignalNoiseRatio +from torchmetrics.image import PeakSignalNoiseRatio from torchmetrics.functional import structural_similarity_index_measure from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity -from torchmetrics import JaccardIndex from nerfstudio.cameras.rays import RayBundle from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation @@ -40,9 +39,8 @@ from nerfstudio.utils import colormaps from nerfstudio.data.dataparsers.base_dataparser import Semantics from nerfstudio.models.nerfacto import NerfactoModelConfig -from nerfstudio.field_components.encodings import NeRFEncoding -from fruit_nerf.fruit_field import FruitField, SemanticNeRFField +from fruit_nerf.fruit_field import FruitField from fruit_nerf.components.ray_samplers import UniformSamplerWithNoise @@ -59,11 +57,7 @@ class FruitNerfModelConfig(NerfactoModelConfig): class FruitModel(Model): - """FruitModel based on Nerfacto model - - Args: - config: FruitModel configuration to instantiate model - """ + """FruitModel based on Nerfacto model""" config: FruitNerfModelConfig @@ -78,10 +72,7 @@ def populate_modules(self): """Set the fields and modules.""" super().populate_modules() - if self.config.disable_scene_contraction: - scene_contraction = None - else: - scene_contraction = SceneContraction(order=float("inf")) + scene_contraction = SceneContraction(order=float("inf")) if not self.config.disable_scene_contraction else None # Fields self.field = FruitField( @@ -100,10 +91,11 @@ def populate_modules(self): num_semantic_classes=1, pass_semantic_gradients=self.config.pass_semantic_gradients, ) - # Build the proposal network(s) + self.density_fns = [] - num_prop_nets = self.config.num_proposal_iterations self.proposal_networks = torch.nn.ModuleList() + num_prop_nets = self.config.num_proposal_iterations + if self.config.use_same_proposal_network: assert len(self.config.proposal_net_args_list) == 1, "Only one proposal network is allowed." 
prop_net_args = self.config.proposal_net_args_list[0] @@ -125,7 +117,7 @@ def populate_modules(self): implementation=self.config.implementation, ) self.proposal_networks.append(network) - self.density_fns.extend([network.density_fn for network in self.proposal_networks]) + self.density_fns.append(network.density_fn) def update_schedule(step): return np.clip( @@ -134,18 +126,13 @@ def update_schedule(step): self.config.proposal_update_every, ) - # Samplers - # Change proposal network initial sampler if uniform - initial_sampler = None # None is for piecewise as default (see ProposalNetworkSampler) + initial_sampler = None if self.test_mode == 'inference': - self.num_inference_samples = None # int(200) - self.proposal_sampler = None # UniformSamplerWithNoise(num_samples=self.num_inference_samples, single_jitter=True) + self.num_inference_samples = None + self.proposal_sampler = None self.field.spatial_distortion = None elif self.config.proposal_initial_sampler == "uniform": - # Change proposal network initial sampler if uniform - initial_sampler = None # None is for piecewise as default (see ProposalNetworkSampler) - if self.config.proposal_initial_sampler == "uniform": - initial_sampler = UniformSampler(single_jitter=self.config.use_single_jitter) + initial_sampler = UniformSampler(single_jitter=self.config.use_single_jitter) else: self.proposal_sampler = ProposalNetworkSampler( num_nerf_samples_per_ray=self.config.num_nerf_samples_per_ray, @@ -156,96 +143,51 @@ def update_schedule(step): initial_sampler=initial_sampler, ) - # Collider self.collider = NearFarCollider(near_plane=self.config.near_plane, far_plane=self.config.far_plane) - # renderers + # Renderers self.renderer_rgb = RGBRenderer(background_color=self.config.background_color) self.renderer_accumulation = AccumulationRenderer() self.renderer_depth = DepthRenderer() self.renderer_uncertainty = UncertaintyRenderer() self.renderer_semantics = SemanticRenderer() - # losses + # Losses self.rgb_loss = MSELoss() self.binary_cross_entropy_loss = torch.nn.BCEWithLogitsLoss(reduction="mean") - # metrics + # Metrics self.psnr = PeakSignalNoiseRatio(data_range=1.0) self.ssim = structural_similarity_index_measure self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) - def setup_inference(self, render_rgb, num_inference_samples): - self.render_rgb = render_rgb # True - self.num_inference_samples = num_inference_samples # int(200) - self.proposal_sampler = UniformSamplerWithNoise(num_samples=self.num_inference_samples, single_jitter=False) - self.field.spatial_distortion = None - def get_param_groups(self) -> Dict[str, List[Parameter]]: - param_groups = {} - param_groups["proposal_networks"] = list(self.proposal_networks.parameters()) - param_groups["fields"] = list(self.field.parameters()) - return param_groups - - def get_training_callbacks( - self, training_callback_attributes: TrainingCallbackAttributes - ) -> List[TrainingCallback]: - callbacks = [] - if self.config.use_proposal_weight_anneal: - # anneal the weights of the proposal network before doing PDF sampling - N = self.config.proposal_weights_anneal_max_num_iters - - def set_anneal(step): - # https://arxiv.org/pdf/2111.12077.pdf eq. 
18 - train_frac = np.clip(step / N, 0, 1) - - def bias(x, b): - return b * x / ((b - 1) * x + 1) - - anneal = bias(train_frac, self.config.proposal_weights_anneal_slope) - self.proposal_sampler.set_anneal(anneal) - - callbacks.append( - TrainingCallback( - where_to_run=[TrainingCallbackLocation.BEFORE_TRAIN_ITERATION], - update_every_num_iters=1, - func=set_anneal, - ) - ) - callbacks.append( - TrainingCallback( - where_to_run=[TrainingCallbackLocation.AFTER_TRAIN_ITERATION], - update_every_num_iters=1, - func=self.proposal_sampler.step_cb, - ) - ) - return callbacks + return { + "proposal_networks": list(self.proposal_networks.parameters()), + "fields": list(self.field.parameters()) + } def get_inference_outputs(self, ray_bundle: RayBundle, render_rgb: bool = False): outputs = {} - ray_samples = self.proposal_sampler(ray_bundle) field_outputs = self.field.forward(ray_samples, render_rgb=render_rgb) if render_rgb: outputs["rgb"] = field_outputs[FieldHeadNames.RGB] - outputs['point_location'] = ray_samples.frustums.get_positions() + outputs["point_location"] = ray_samples.frustums.get_positions() outputs["semantics"] = field_outputs[FieldHeadNames.SEMANTICS][..., 0] outputs["density"] = field_outputs[FieldHeadNames.DENSITY][..., 0] semantic_labels = torch.sigmoid(outputs["semantics"]) threshold = 0.9 semantic_labels = torch.heaviside(semantic_labels - threshold, torch.tensor(0.)).to(torch.long) - outputs["semantics_colormap"] = semantic_labels return outputs - def get_outputs(self, ray_bundle: RayBundle): # - + def get_outputs(self, ray_bundle: RayBundle): ray_samples, weights_list, ray_samples_list = self.proposal_sampler(ray_bundle, density_fns=self.density_fns) - field_outputs = self.field.forward(ray_samples) if self.config.use_gradient_scaling: @@ -259,25 +201,25 @@ def get_outputs(self, ray_bundle: RayBundle): # depth = self.renderer_depth(weights=weights, ray_samples=ray_samples) accumulation = self.renderer_accumulation(weights=weights) - outputs = {"rgb": rgb, "accumulation": accumulation, "depth": depth, "weights_list": weights_list, - "ray_samples_list": ray_samples_list} + outputs = { + "rgb": rgb, + "accumulation": accumulation, + "depth": depth, + "weights_list": weights_list, + "ray_samples_list": ray_samples_list + } for i in range(self.config.num_proposal_iterations): outputs[f"prop_depth_{i}"] = self.renderer_depth(weights=weights_list[i], ray_samples=ray_samples_list[i]) - # semantics - semantic_weights = weights - if not self.config.pass_semantic_gradients: - semantic_weights = semantic_weights.detach() - outputs["semantics"] = self.renderer_semantics( - field_outputs[FieldHeadNames.SEMANTICS], weights=semantic_weights - ) + # Semantics + semantic_weights = weights if self.config.pass_semantic_gradients else weights.detach() + outputs["semantics"] = self.renderer_semantics(field_outputs[FieldHeadNames.SEMANTICS], weights=semantic_weights) - # semantics colormaps + # Semantics colormaps semantic_labels = torch.sigmoid(outputs["semantics"].detach()) threshold = 0.9 semantic_labels = torch.heaviside(semantic_labels - threshold, torch.tensor(0.)).to(torch.long) - outputs["semantics_colormap"] = self.colormap.to(self.device)[semantic_labels] return outputs @@ -285,8 +227,23 @@ def get_outputs(self, ray_bundle: RayBundle): # def get_loss_dict(self, outputs, batch, metrics_dict=None): loss_dict = {} image = batch["image"].to(self.device) - loss_dict["rgb_loss"] = self.rgb_loss(image, outputs["rgb"]) - + + # Ensure outputs_rgb is always defined + outputs_rgb = outputs["rgb"] + + 
# Handle channel mismatch + if outputs["rgb"].shape[-1] != image.shape[-1]: + if outputs["rgb"].shape[-1] == 4 and image.shape[-1] == 3: + outputs_rgb = outputs["rgb"][..., :3] # Use only the RGB channels from the output + elif outputs["rgb"].shape[-1] == 3 and image.shape[-1] == 4: + image = image[..., :3] # Use only the RGB channels from the ground truth + else: + raise ValueError(f"Unexpected channel size in tensors: outputs['rgb'] shape {outputs['rgb'].shape}, image shape {image.shape}") + + # Calculate RGB loss + loss_dict["rgb_loss"] = self.rgb_loss(image, outputs_rgb) + + # Other loss calculations can go here loss_dict["semantics_loss"] = self.config.semantic_loss_weight * self.binary_cross_entropy_loss( outputs["semantics"], batch["fruit_mask"] ) @@ -294,33 +251,31 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None): loss_dict["interlevel_loss"] = self.config.interlevel_loss_mult * interlevel_loss( outputs["weights_list"], outputs["ray_samples_list"] ) - + return loss_dict - def forward(self, ray_bundle: RayBundle) -> Dict[str, Union[torch.Tensor, List]]: - """Run forward starting with a ray bundle. This outputs different things depending on the configuration - of the model and whether or not the batch is provided (whether or not we are training basically) - - Args: - ray_bundle: containing all the information needed to render that ray latents included - """ - - if self.collider is not None: - ray_bundle = self.collider(ray_bundle) - - if self.test_mode == 'inference': - # fruit_nerf_output = self.get_inference_outputs(ray_bundle, self.render_rgb) - fruit_nerf_output = self.get_inference_outputs(ray_bundle, True) - else: - fruit_nerf_output = self.get_outputs(ray_bundle) - - return fruit_nerf_output - def get_metrics_dict(self, outputs, batch): metrics_dict = {} image = batch["image"].to(self.device) - metrics_dict["psnr"] = self.psnr(outputs["rgb"], image) + + # Ensure outputs_rgb is always defined + outputs_rgb = outputs["rgb"] + + # Handle channel mismatch + if outputs["rgb"].shape[-1] != image.shape[-1]: + if outputs["rgb"].shape[-1] == 4 and image.shape[-1] == 3: + outputs_rgb = outputs["rgb"][..., :3] # Use only the RGB channels from the output + elif outputs["rgb"].shape[-1] == 3 and image.shape[-1] == 4: + image = image[..., :3] # Use only the RGB channels from the ground truth + else: + raise ValueError(f"Unexpected channel size in tensors: outputs['rgb'] shape {outputs['rgb'].shape}, image shape {image.shape}") + + # Compute PSNR + metrics_dict["psnr"] = self.psnr(outputs_rgb, image) + + # Compute distortion loss metrics_dict["distortion"] = distortion_loss(outputs["weights_list"], outputs["ray_samples_list"]) + return metrics_dict def get_image_metrics_and_images( @@ -362,15 +317,11 @@ def get_image_metrics_and_images( images_dict[key] = prop_depth_i # semantics - # semantic_labels = torch.argmax(torch.nn.functional.softmax(outputs["semantics"], dim=-1), dim=-1) semantic_labels = torch.sigmoid(outputs["semantics"]) - images_dict[ - "semantics_colormap"] = semantic_labels + images_dict["semantics_colormap"] = semantic_labels # valid mask images_dict["fruit_mask"] = batch["fruit_mask"].repeat(1, 1, 3).to(self.device) - # batch["fruit_mask"][batch["fruit_mask"] < 0.1] = 0 - # batch["fruit_mask"][batch["fruit_mask"] >= 0.1] = 1 from torchmetrics.classification import BinaryJaccardIndex metric = BinaryJaccardIndex().to(self.device) @@ -379,7 +330,3 @@ def get_image_metrics_and_images( metrics_dict["iou"] = float(iou) return metrics_dict, images_dict - - -class 
FruitModelMLP(Model): - pass diff --git a/fruit_nerf/fruit_pipeline.py b/fruit_nerf/fruit_pipeline.py index 7f5b66b..085cb00 100644 --- a/fruit_nerf/fruit_pipeline.py +++ b/fruit_nerf/fruit_pipeline.py @@ -1,54 +1,31 @@ -# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """ -Abstracts for the Pipeline class. +FruitPipeline implementation. """ + from __future__ import annotations -import typing -from abc import abstractmethod from dataclasses import dataclass, field +import torch.nn +from typing import Any, Dict, Literal, Optional, Tuple, Type, List +from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, MofNCompleteColumn +from PIL import Image +from time import time from pathlib import Path -from time import time -from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Type, Union, cast + import torch -import torch.distributed as dist -from PIL import Image -from rich.progress import ( - BarColumn, - MofNCompleteColumn, - Progress, - TextColumn, - TimeElapsedColumn, -) from torch import nn from torch.nn import Parameter -from torch.nn.parallel import DistributedDataParallel as DDP from torch.cuda.amp.grad_scaler import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist from nerfstudio.configs import base_config as cfg -from nerfstudio.data.datamanagers.base_datamanager import ( - DataManager, - DataManagerConfig, - VanillaDataManager, -) +from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.models.base_model import Model, ModelConfig -from nerfstudio.utils import profiler from nerfstudio.pipelines.base_pipeline import Pipeline +from nerfstudio.utils import profiler @dataclass @@ -56,31 +33,12 @@ class FruitPipelineConfig(cfg.InstantiateConfig): """Configuration for pipeline instantiation""" _target: Type = field(default_factory=lambda: FruitPipeline) - """target class to instantiate""" datamanager: DataManagerConfig = DataManagerConfig() - """specifies the datamanager config""" model: ModelConfig = ModelConfig() - """specifies the model config""" class FruitPipeline(Pipeline): - """The pipeline class for the vanilla nerf setup of multiple cameras for one or a few scenes. 
- - Args: - config: configuration to instantiate pipeline - device: location to place model and data - test_mode: - 'val': loads train/val datasets into memory - 'test': loads train/test dataset into memory - 'inference': does not load any dataset into memory - world_size: total number of machines available - local_rank: rank of current machine - grad_scaler: gradient scaler used in the trainer - - Attributes: - datamanager: The data manager that will be used - model: The model that will be used - """ + """The pipeline class for the vanilla nerf setup of multiple cameras for one or a few scenes.""" def __init__( self, @@ -98,7 +56,7 @@ def __init__( device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank ) self.datamanager.to(device) - # TODO(ethan): get rid of scene_bounds from the model + assert self.datamanager.train_dataset is not None, "Missing input dataset" self._model = config.model.setup( @@ -114,78 +72,50 @@ def __init__( self.world_size = world_size if world_size > 1: - self._model = typing.cast(Model, DDP(self._model, device_ids=[local_rank], find_unused_parameters=True)) + self._model = DDP(self._model, device_ids=[local_rank], find_unused_parameters=True) dist.barrier(device_ids=[local_rank]) @profiler.time_function def get_train_loss_dict(self, step: int): - """This function gets your training loss dict. This will be responsible for - getting the next batch of data from the DataManager and interfacing with the - Model class, feeding the data to the model's forward function. + """This function gets your training loss dict.""" - Args: - step: current iteration step to update sampler if using DDP (distributed) - """ ray_bundle, batch = self.datamanager.next_train(step) - model_outputs = self._model(ray_bundle) # train distributed data parallel model if world_size > 1 + model_outputs = self._model(ray_bundle) + metrics_dict = self.model.get_metrics_dict(model_outputs, batch) - if self.config.datamanager.camera_optimizer is not None: - camera_opt_param_group = self.config.datamanager.camera_optimizer.param_group - if camera_opt_param_group in self.datamanager.get_param_groups(): - # Report the camera optimization metrics - metrics_dict["camera_opt_translation"] = ( - self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, :3].norm() - ) - metrics_dict["camera_opt_rotation"] = ( - self.datamanager.get_param_groups()[camera_opt_param_group][0].data[:, 3:].norm() - ) + if "camera_opt" in self.model.get_param_groups(): + param_group = self.model.get_param_groups()["camera_opt"] + metrics_dict["camera_opt_translation"] = param_group[0].data[:, :3].norm() + metrics_dict["camera_opt_rotation"] = param_group[0].data[:, 3:].norm() loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict) return model_outputs, loss_dict, metrics_dict def forward(self): - """Blank forward method - - This is an nn.Module, and so requires a forward() method normally, although in our case - we do not need a forward() method""" + """Blank forward method""" raise NotImplementedError - @profiler.time_function def get_eval_image_metrics_and_images(self, step: int): - """This function gets your evaluation loss dict. 
It needs to get the data - from the DataManager and feed it to the model's forward function + """This function gets your evaluation loss dict.""" - Args: - step: current iteration step - """ self.eval() image_idx, camera_ray_bundle, batch = self.datamanager.next_eval_image(step) outputs = self.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle) metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch) - assert "image_idx" not in metrics_dict metrics_dict["image_idx"] = image_idx - assert "num_rays" not in metrics_dict metrics_dict["num_rays"] = len(camera_ray_bundle) self.train() return metrics_dict, images_dict @profiler.time_function def get_average_eval_image_metrics(self, step: Optional[int] = None, output_path: Optional[Path] = None): - """Iterate over all the images in the eval dataset and get the average. + """Iterate over all the images in the eval dataset and get the average.""" - Args: - step: current training step - output_path: optional path to save rendered images to - - Returns: - metrics_dict: dictionary of metrics - """ self.eval() metrics_dict_list = [] - assert isinstance(self.datamanager, VanillaDataManager) num_images = len(self.datamanager.fixed_indices_eval_dataloader) with Progress( TextColumn("[progress.description]{task.description}"), @@ -196,7 +126,6 @@ def get_average_eval_image_metrics(self, step: Optional[int] = None, output_path ) as progress: task = progress.add_task("[green]Evaluating all eval images...", total=num_images) for camera_ray_bundle, batch in self.datamanager.fixed_indices_eval_dataloader: - # time this the following line inner_start = time() height, width = camera_ray_bundle.shape num_rays = height * width @@ -205,19 +134,15 @@ def get_average_eval_image_metrics(self, step: Optional[int] = None, output_path if output_path is not None: camera_indices = camera_ray_bundle.camera_indices - assert camera_indices is not None for key, val in images_dict.items(): Image.fromarray((val * 255).byte().cpu().numpy()).save( output_path / "{0:06d}-{1}.jpg".format(int(camera_indices[0, 0, 0]), key) ) - assert "num_rays_per_sec" not in metrics_dict metrics_dict["num_rays_per_sec"] = num_rays / (time() - inner_start) - fps_str = "fps" - assert fps_str not in metrics_dict - metrics_dict[fps_str] = metrics_dict["num_rays_per_sec"] / (height * width) + metrics_dict["fps"] = metrics_dict["num_rays_per_sec"] / (height * width) metrics_dict_list.append(metrics_dict) progress.advance(task) - # average the metrics list + metrics_dict = {} for key in metrics_dict_list[0].keys(): metrics_dict[key] = float( @@ -227,12 +152,8 @@ def get_average_eval_image_metrics(self, step: Optional[int] = None, output_path return metrics_dict def load_pipeline(self, loaded_state: Dict[str, Any], step: int) -> None: - """Load the checkpoint from the given path + """Load the checkpoint from the given path""" - Args: - loaded_state: pre-trained model state dict - step: training step of the loaded checkpoint - """ state = { (key[len("module."):] if key.startswith("module.") else key): value for key, value in loaded_state.items() } @@ -249,12 +170,8 @@ def get_training_callbacks( return callbacks def get_param_groups(self) -> Dict[str, List[Parameter]]: - """Get the param groups for the pipeline. + """Get the param groups for the pipeline.""" - Returns: - A list of dictionaries containing the pipeline's param groups. 
- """ datamanager_params = self.datamanager.get_param_groups() model_params = self.model.get_param_groups() - # TODO(ethan): assert that key names don't overlap return {**datamanager_params, **model_params} diff --git a/pyproject.toml b/pyproject.toml index c9fcb85..6db6425 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,15 +7,15 @@ dependencies=[ "numpy", "tqdm", "rawpy", - "nerfstudio==0.3.2", - "supervision==0.18", - "timm==0.9.2", - "pyransac3d==0.6.0", - "alphashape==1.3.1", - "robust_laplacian==0.2.7", - "polyscope==2.2.1", - "hausdorff==0.2.6", - "numba==0.58.1", + "nerfstudio>=1.1.4", + "supervision>=0.23.0", + "timm>=0.6.7", + "pyransac3d>=0.6.0", + "alphashape>=1.3.1", + "robust_laplacian>=0.2.7", + "polyscope>=2.2.1", + "hausdorff>=0.2.6", + "numba>=0.58.1", ] [tool.setuptools.packages.find]