diff --git a/GET3D_NADA/README.md b/GET3D_NADA/README.md
new file mode 100644
index 0000000..81fd7a7
--- /dev/null
+++ b/GET3D_NADA/README.md
@@ -0,0 +1,114 @@
+# Text-guided 3D synthesis by GET3D + NADA
+
+| | |
+|-----------------------------------------------------|---------------------------------------------|
+| Car → Police | Car → Sketch |
+|  |  |
+| Motorbike → Tiger | Shoe → Mossy |
+|  |  |
+
+> You can create other stylized 3D objects of your own!
+
+
+
+---
+
+## Requirements setup
+
+* First, set up the GET3D environment (venv or Docker image).
+* Then, install the extra requirements with `pip install -r extra_requirements.txt`
+
+
+
+---
+
+## Download checkpoints
+
+For GET3D + NADA, you need a pretrained model checkpoint. Set the path to the downloaded checkpoint in the yaml file.
+
+- Car, Chair, Table, Motorbike → [link](https://github.com/nv-tlabs/GET3D/tree/master/pretrained_model)
+
+- Fruits, Shoe → [link](https://huggingface.co/datasets/allenai/objaverse/discussions/1#63c0441bd9e14fd8875cec97)
+
+
+
+---
+
+## Train
+
+### Train code
+
+If you want to train the model, refer to the training script below.
+
+```
+$ python train_nada.py --config_path='experiments/{}.yaml' --name='{}' --suppress
+
+optional arguments
+ --config_path select the yaml file to run (in the experiments folder)
+ --name any name you want for the log file (optional)
+ --suppress store only the latest & best pkl files
+
+EX)
+$ python train_nada.py --config_path='experiments/car_police_example.yaml' --name='car_police' --suppress
+```
+
+
+
+### Trainable Parameters
+
+When you open a yaml file, you will see many trainable parameters and configs.
+
+Below are some important parameters you may want to change as you conduct experiments.
+
+We provide some yaml files as [examples](./experiments).
+
+
+
+**Global Config**
+
+| | Default Setting | Detailed explanation |
+|----------|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| batch    | 3               | Setting the batch size to less than 3 produced unfavorable results in most of our experiments. However, you can change this value to whatever fits your experiments |
+| iter_1st | 1               | In most cases, 1 iteration was enough to generate the 3D object you want. You can increase this value to see more changes in the generated objects |
+| iter_2nd | 30              | Since the model has largely converged after iter_1st, the default was sufficient in most cases. You can increase this value to see more changes in the generated objects |
+
+
+
+**GET3D config**
+
+| | Default Setting | Detailed explanation |
+|---------|-----------------|--------------------------------------------------------------------------------------------|
+| n_views | 12              | You can change this value to fit your GPU memory. According to the paper, set n_views >= 16 |
+
+
+
+**NADA config**
+
+| | Default Setting | Detailed explanation |
+|-------------------------|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| lr                      | 0.002             | lr = 0.002 was suitable for most of our experiments. However, you can change this value to fit your task. |
+| auto_layer_k            | 20 , 30           | auto_layer_k is the number of trainable layers during the adaptation of GET3D. Empirically, setting auto_layer_k to 20 worked well for texture changes + slight shape changes, and 30 for texture-only changes. |
+| source text             | pretrained object | For most experiments, we simply set the source text to the pretrained object class. However, a longer text prompt improved results in some cases. EX) 3D object car |
+| target text             | target object     | For most experiments, we simply set the target text to the target object class. However, a longer text prompt improved results in some cases. EX) 3D render in the style of Pixar |
+| gradient_clip_threshold | -1                | Not using gradient clipping (set to -1) was suitable for most experiments. However, if the task requires major changes in shape, gradient clipping was helpful. |
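The source/target text pair defines a direction in CLIP embedding space along which the generator is adapted. As a rough, stdlib-only sketch of the idea (independent of CLIP; the real computation is `compute_text_direction` in `clip_loss.py`), the direction is the normalized mean difference between target and source text features:

```python
import math

def normalize(v):
    # L2-normalize a vector
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def text_direction(source_feats, target_feats):
    # Mean over templates of (target - source), then L2-normalize,
    # mirroring compute_text_direction in clip_loss.py.
    dims = len(source_feats[0])
    mean_diff = [
        sum(t[d] - s[d] for s, t in zip(source_feats, target_feats)) / len(source_feats)
        for d in range(dims)
    ]
    return normalize(mean_diff)

# Toy 2-D "embeddings" standing in for CLIP text features of two prompt templates
src = [[1.0, 0.0], [1.0, 0.0]]
tgt = [[0.0, 1.0], [0.0, 1.0]]
direction = text_direction(src, tgt)  # a unit vector pointing from source toward target
```

During training, the per-image edit direction (target render minus source render in CLIP space) is pushed toward this text direction by the directional loss.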
+
+
+
+---
+
+## Inference
+
+* You can run inference with a NADA checkpoint using the same code (`../train_3d.py`).
+
+---
+
+## Appendix
+
+### CLIP util
+* We provide `clip_save.py` to avoid the 'connection reset by peer' error from the CLIP library, which unexpectedly stops the runtime.
+
+1. Run `python clip_save.py`, and you will get `clip-cnn.pt` / `clip-vit-b-16.pt` / `clip-vit-b-32.pt`
+2. Change the `clip.load()` arguments as follows (note that these calls are in `clip_loss.py`)
+ - `clip.load('RN50')` → `clip.load('/PATH/TO/clip-cnn.pt')`
+ - `clip.load('ViT-B/16')` → `clip.load('/PATH/TO/clip-vit-b-16.pt')`
+ - `clip.load('ViT-B/32')` → `clip.load('/PATH/TO/clip-vit-b-32.pt')`
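If you prefer the local checkpoints to be used only when they actually exist, the substitution in step 2 can be wrapped in a small helper (a hypothetical sketch, not part of this repo; `resolve_clip_ckpt` is our own name):

```python
import os

def resolve_clip_ckpt(model_name, local_path):
    # Prefer a locally saved jit checkpoint (produced by clip_save.py);
    # fall back to the CLIP model name so clip.load() downloads it instead.
    return local_path if os.path.isfile(local_path) else model_name

# Usage inside clip_loss.py would then look like:
#   model, preprocess = clip.load(resolve_clip_ckpt('RN50', '/PATH/TO/clip-cnn.pt'), device=device)
```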
diff --git a/GET3D_NADA/assets/nada_car_police.gif b/GET3D_NADA/assets/nada_car_police.gif
new file mode 100644
index 0000000..6cb8c69
Binary files /dev/null and b/GET3D_NADA/assets/nada_car_police.gif differ
diff --git a/GET3D_NADA/assets/nada_car_sketch.gif b/GET3D_NADA/assets/nada_car_sketch.gif
new file mode 100644
index 0000000..01ae525
Binary files /dev/null and b/GET3D_NADA/assets/nada_car_sketch.gif differ
diff --git a/GET3D_NADA/assets/nada_motorbike_tiger.gif b/GET3D_NADA/assets/nada_motorbike_tiger.gif
new file mode 100644
index 0000000..ded81e8
Binary files /dev/null and b/GET3D_NADA/assets/nada_motorbike_tiger.gif differ
diff --git a/GET3D_NADA/assets/nada_shoes_mossy.gif b/GET3D_NADA/assets/nada_shoes_mossy.gif
new file mode 100644
index 0000000..7eb8a2c
Binary files /dev/null and b/GET3D_NADA/assets/nada_shoes_mossy.gif differ
diff --git a/GET3D_NADA/clip_loss.py b/GET3D_NADA/clip_loss.py
new file mode 100644
index 0000000..cbf8fef
--- /dev/null
+++ b/GET3D_NADA/clip_loss.py
@@ -0,0 +1,362 @@
+# Copyright (c) 2021 rinongal
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""
+Loss-class with CLIP
+
+Reference: https://github.com/rinongal/StyleGAN-nada/blob/dc8406ae2173ad186f8f03f3cadf65e613ac9364/ZSSGAN/criteria/clip_loss.py
+"""
+
+import torch
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+
+import numpy as np
+
+import clip
+from PIL import Image
+from text_templates import imagenet_templates, part_templates
+
+
+class DirectionLoss(torch.nn.Module):
+
+ def __init__(self, loss_type='mse'):
+ super(DirectionLoss, self).__init__()
+
+ self.loss_type = loss_type
+
+ self.loss_func = {
+ 'mse': torch.nn.MSELoss,
+ 'cosine': torch.nn.CosineSimilarity,
+ 'mae': torch.nn.L1Loss
+ }[loss_type]()
+
+ def forward(self, x, y):
+ if self.loss_type == "cosine":
+ return 1. - self.loss_func(x, y)
+
+ return self.loss_func(x, y)
+
+
+class CLIPLoss(torch.nn.Module):
+ def __init__(self, device, lambda_direction=1., lambda_patch=0., lambda_global=0., lambda_manifold=0., lambda_texture=0., patch_loss_type='mae', direction_loss_type='cosine', clip_model='ViT-B/32'):
+ super(CLIPLoss, self).__init__()
+
+ self.device = device
+ self.model, clip_preprocess = clip.load(clip_model, device=self.device)
+
+ self.clip_preprocess = clip_preprocess
+
+ self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1].
+ clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions
+ clip_preprocess.transforms[4:]) # + skip convert PIL to tensor
+
+ self.target_direction = None
+ self.patch_text_directions = None
+
+ self.patch_loss = DirectionLoss(patch_loss_type)
+ self.direction_loss = DirectionLoss(direction_loss_type)
+ self.patch_direction_loss = torch.nn.CosineSimilarity(dim=2)
+
+ self.lambda_global = lambda_global
+ self.lambda_patch = lambda_patch
+ self.lambda_direction = lambda_direction
+ self.lambda_manifold = lambda_manifold
+ self.lambda_texture = lambda_texture
+
+ self.src_text_features = None
+ self.target_text_features = None
+ self.angle_loss = torch.nn.L1Loss()
+
+ self.model_cnn, preprocess_cnn = clip.load("RN50", device=self.device)
+ # self.model_cnn, preprocess_cnn = clip.load("./clip-cnn.pt", device=self.device)
+ self.preprocess_cnn = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1].
+ preprocess_cnn.transforms[:2] + # to match CLIP input scale assumptions
+ preprocess_cnn.transforms[4:]) # + skip convert PIL to tensor
+
+ self.model.requires_grad_(False)
+ self.model_cnn.requires_grad_(False)
+
+ self.texture_loss = torch.nn.MSELoss()
+
+ def tokenize(self, strings: list):
+ return clip.tokenize(strings).to(self.device)
+
+ def encode_text(self, tokens: list) -> torch.Tensor:
+ return self.model.encode_text(tokens)
+
+ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
+ images = self.preprocess(images).to(self.device)
+ return self.model.encode_image(images)
+
+ def encode_images_with_cnn(self, images: torch.Tensor) -> torch.Tensor:
+ images = self.preprocess_cnn(images).to(self.device)
+ return self.model_cnn.encode_image(images)
+
+ def distance_with_templates(self, img: torch.Tensor, class_str: str, templates=imagenet_templates) -> torch.Tensor:
+
+ text_features = self.get_text_features(class_str, templates)
+ image_features = self.get_image_features(img)
+
+ similarity = image_features @ text_features.T
+
+ return 1. - similarity
+
+ def get_text_features(self, class_str: str, templates=imagenet_templates, norm: bool = True) -> torch.Tensor:
+ template_text = self.compose_text_with_templates(class_str, templates)
+
+ tokens = clip.tokenize(template_text).to(self.device)
+
+ text_features = self.encode_text(tokens).detach()
+
+ if norm:
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+
+ return text_features
+
+ def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
+ image_features = self.encode_images(img)
+
+ if norm:
+ image_features /= image_features.clone().norm(dim=-1, keepdim=True)
+
+ return image_features
+
+ def preprocessing_image(self, img): # added: open an image file and return its normalized CLIP embedding
+ preprocessed = self.clip_preprocess(Image.open(img)).unsqueeze(0).to(self.device)
+ encoding = self.model.encode_image(preprocessed)
+ encoding /= encoding.norm(dim=-1, keepdim=True)
+
+ return encoding
+
+ def compute_loss(self, source_class: torch.Tensor, target_class: torch.Tensor):
+ loss = F.cosine_similarity(source_class, target_class)
+ loss = 1-loss
+ loss = torch.mean(loss)
+
+ return loss
+
+ # Templated mean text feature
+ def templated_mean_text(self, source_class: str):
+ source_features = self.get_text_features(source_class).mean(axis=0, keepdim=True)
+
+ return source_features
+
+ # Non-templated text feature
+ def non_templated_text(self, source_class: str):
+ source_tokens = clip.tokenize(source_class).to(self.device)
+ source_features = self.encode_text(source_tokens)
+
+ return source_features
+
+ def compute_text_direction(self, source_class: str, target_class: str) -> torch.Tensor:
+ source_features = self.get_text_features(source_class)
+ target_features = self.get_text_features(target_class)
+
+ text_direction = (target_features - source_features).mean(axis=0, keepdim=True)
+ text_direction /= text_direction.norm(dim=-1, keepdim=True)
+
+ return text_direction
+
+ def compute_img2img_direction(self, source_images: torch.Tensor, target_images: list) -> torch.Tensor:
+ with torch.no_grad():
+
+ src_encoding = self.get_image_features(source_images)
+ src_encoding = src_encoding.mean(dim=0, keepdim=True)
+
+ target_encodings = []
+ for target_img in target_images:
+
+ preprocessed = self.clip_preprocess(Image.open(target_img)).unsqueeze(0).to(self.device)
+
+ encoding = self.model.encode_image(preprocessed)
+ encoding /= encoding.norm(dim=-1, keepdim=True)
+
+ target_encodings.append(encoding)
+
+ target_encoding = torch.cat(target_encodings, dim=0)
+ target_encoding = target_encoding.mean(dim=0, keepdim=True)
+
+ direction = target_encoding - src_encoding
+ direction /= direction.norm(dim=-1, keepdim=True)
+
+ return direction
+
+ def set_text_features(self, source_class: str, target_class: str) -> None:
+ source_features = self.get_text_features(source_class).mean(axis=0, keepdim=True)
+ self.src_text_features = source_features / source_features.norm(dim=-1, keepdim=True)
+
+ target_features = self.get_text_features(target_class).mean(axis=0, keepdim=True)
+ self.target_text_features = target_features / target_features.norm(dim=-1, keepdim=True)
+
+ def clip_angle_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
+ if self.src_text_features is None:
+ self.set_text_features(source_class, target_class)
+
+ cos_text_angle = self.target_text_features @ self.src_text_features.T
+ text_angle = torch.acos(cos_text_angle)
+
+ src_img_features = self.get_image_features(src_img).unsqueeze(2)
+ target_img_features = self.get_image_features(target_img).unsqueeze(1)
+
+ cos_img_angle = torch.clamp(target_img_features @ src_img_features, min=-1.0, max=1.0)
+ img_angle = torch.acos(cos_img_angle)
+
+ text_angle = text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)
+ cos_text_angle = cos_text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)
+
+ return self.angle_loss(cos_img_angle, cos_text_angle)
+
+ def compose_text_with_templates(self, text: str, templates=imagenet_templates) -> list:
+ return [template.format(text) for template in templates]
+
+ def clip_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
+
+ if self.target_direction is None:
+ self.target_direction = self.compute_text_direction(source_class, target_class)
+
+ src_encoding = self.get_image_features(src_img)
+ target_encoding = self.get_image_features(target_img)
+ edit_direction = (target_encoding - src_encoding)
+ if edit_direction.sum() == 0:
+ target_encoding = self.get_image_features(target_img + 1e-6)
+ edit_direction = (target_encoding - src_encoding)
+
+ edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True))
+ # return self.direction_loss(edit_direction, self.target_direction).mean() # original
+ return self.direction_loss(edit_direction, self.target_direction)
+
+ def global_clip_loss(self, img: torch.Tensor, text) -> torch.Tensor:
+ if not isinstance(text, list):
+ text = [text]
+
+ tokens = clip.tokenize(text).to(self.device)
+ image = self.preprocess(img)
+
+ logits_per_image, _ = self.model(image, tokens)
+
+ return (1. - logits_per_image / 100).mean()
+
+ def random_patch_centers(self, img_shape, num_patches, size):
+ batch_size, channels, height, width = img_shape
+
+ half_size = size // 2
+ patch_centers = np.concatenate([np.random.randint(half_size, width - half_size, size=(batch_size * num_patches, 1)),
+ np.random.randint(half_size, height - half_size, size=(batch_size * num_patches, 1))], axis=1)
+
+ return patch_centers
+
+ def generate_patches(self, img: torch.Tensor, patch_centers, size):
+ batch_size = img.shape[0]
+ num_patches = len(patch_centers) // batch_size
+ half_size = size // 2
+
+ patches = []
+
+ for batch_idx in range(batch_size):
+ for patch_idx in range(num_patches):
+
+ center_x = patch_centers[batch_idx * num_patches + patch_idx][0]
+ center_y = patch_centers[batch_idx * num_patches + patch_idx][1]
+
+ patch = img[batch_idx:batch_idx+1, :, center_y - half_size:center_y + half_size, center_x - half_size:center_x + half_size]
+
+ patches.append(patch)
+
+ patches = torch.cat(patches, dim=0)
+
+ return patches
+
+ def patch_scores(self, img: torch.Tensor, class_str: str, patch_centers, patch_size: int) -> torch.Tensor:
+
+ parts = self.compose_text_with_templates(class_str, part_templates)
+ tokens = clip.tokenize(parts).to(self.device)
+ text_features = self.encode_text(tokens).detach()
+
+ patches = self.generate_patches(img, patch_centers, patch_size)
+ image_features = self.get_image_features(patches)
+
+ similarity = image_features @ text_features.T
+
+ return similarity
+
+ def clip_patch_similarity(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
+ patch_size = 196 #TODO remove magic number
+
+ patch_centers = self.random_patch_centers(src_img.shape, 4, patch_size) #TODO remove magic number
+
+ src_scores = self.patch_scores(src_img, source_class, patch_centers, patch_size)
+ target_scores = self.patch_scores(target_img, target_class, patch_centers, patch_size)
+
+ return self.patch_loss(src_scores, target_scores)
+
+ def patch_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
+
+ if self.patch_text_directions is None:
+ src_part_classes = self.compose_text_with_templates(source_class, part_templates)
+ target_part_classes = self.compose_text_with_templates(target_class, part_templates)
+
+ parts_classes = list(zip(src_part_classes, target_part_classes))
+
+ self.patch_text_directions = torch.cat([self.compute_text_direction(pair[0], pair[1]) for pair in parts_classes], dim=0)
+
+ patch_size = 510 # TODO remove magic numbers
+
+ patch_centers = self.random_patch_centers(src_img.shape, 1, patch_size)
+
+ patches = self.generate_patches(src_img, patch_centers, patch_size)
+ src_features = self.get_image_features(patches)
+
+ patches = self.generate_patches(target_img, patch_centers, patch_size)
+ target_features = self.get_image_features(patches)
+
+ edit_direction = (target_features - src_features)
+ edit_direction /= edit_direction.clone().norm(dim=-1, keepdim=True)
+
+ cosine_dists = 1. - self.patch_direction_loss(edit_direction.unsqueeze(1), self.patch_text_directions.unsqueeze(0))
+
+ patch_class_scores = cosine_dists * (edit_direction @ self.patch_text_directions.T).softmax(dim=-1)
+
+ return patch_class_scores.mean()
+
+ def cnn_feature_loss(self, src_img: torch.Tensor, target_img: torch.Tensor) -> torch.Tensor:
+ src_features = self.encode_images_with_cnn(src_img)
+ target_features = self.encode_images_with_cnn(target_img)
+
+ return self.texture_loss(src_features, target_features)
+
+ def forward(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str, texture_image: torch.Tensor = None):
+ clip_loss = 0.0
+
+ if self.lambda_global:
+ clip_loss += self.lambda_global * self.global_clip_loss(target_img, [f"a {target_class}"])
+
+ if self.lambda_patch:
+ clip_loss += self.lambda_patch * self.patch_directional_loss(src_img, source_class, target_img, target_class)
+
+ if self.lambda_direction:
+ clip_loss += self.lambda_direction * self.clip_directional_loss(src_img, source_class, target_img, target_class)
+
+ if self.lambda_manifold:
+ clip_loss += self.lambda_manifold * self.clip_angle_loss(src_img, source_class, target_img, target_class)
+
+ if self.lambda_texture and (texture_image is not None):
+ clip_loss += self.lambda_texture * self.cnn_feature_loss(texture_image, target_img)
+
+ return clip_loss
diff --git a/GET3D_NADA/clip_save.py b/GET3D_NADA/clip_save.py
new file mode 100644
index 0000000..7bca9d2
--- /dev/null
+++ b/GET3D_NADA/clip_save.py
@@ -0,0 +1,41 @@
+"""
+To avoid 'Connection Reset by Peer' error by CLIP library, save checkpoint using torch.jit
+
+Usage
+ - $ python clip_save.py [-o OUTPUT_DIR] [--test] [--device DEVICE]
+"""
+
+
+def parse_args():
+ import argparse
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('-o', '--output_dir', type=str, default='.', help='output checkpoint file directory')
+ parser.add_argument('-t', '--test', action='store_true', help='to test checkpoint')
+ parser.add_argument('-d', '--device', type=str, default='cuda:0', help='device to use in testing')
+ return parser.parse_args()
+
+
+def main(args):
+ import os
+ import torch
+ import clip
+
+ torch.jit.save(clip.load('ViT-B/32', jit=True)[0], os.path.join(args.output_dir, 'clip-vit-b-32.pt'))
+ torch.jit.save(clip.load('ViT-B/16', jit=True)[0], os.path.join(args.output_dir, 'clip-vit-b-16.pt'))
+ torch.jit.save(clip.load('RN50', jit=True)[0], os.path.join(args.output_dir, 'clip-cnn.pt'))
+
+ # Reload the saved checkpoints to verify that they load correctly
+ model32, _ = clip.load(os.path.join(args.output_dir, 'clip-vit-b-32.pt'))
+ model16, _ = clip.load(os.path.join(args.output_dir, 'clip-vit-b-16.pt'))
+ modelcnn, _ = clip.load(os.path.join(args.output_dir, 'clip-cnn.pt'))
+
+ if args.test:
+ text = 'rusty car'
+ embed32 = model32.encode_text(clip.tokenize(text).to(args.device))
+ print(embed32.shape)
+ embed16 = model16.encode_text(clip.tokenize(text).to(args.device))
+ print(embed16.shape)
+ print((embed16 * embed32).sum())
+
+
+if __name__ == '__main__':
+ main(parse_args())
diff --git a/GET3D_NADA/dist_util.py b/GET3D_NADA/dist_util.py
new file mode 100644
index 0000000..966306d
--- /dev/null
+++ b/GET3D_NADA/dist_util.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2023 kdha0727
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""
+Helpers for distributed training.
+"""
+
+import os
+import functools
+import contextlib
+
+import torch
+import torch.distributed as dist
+from torch.cuda import is_available as _cuda_available
+
+
+RANK = 0
+WORLD_SIZE = 1
+
+
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# Setup Tools #
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+
+def is_initialized():
+ """
+ Return whether the c10d (distributed) runtime is initialized.
+
+ If PyTorch isn't compiled with c10d, `dist.is_initialized` is omitted from
+ the namespace; this wrapper falls back to False in that case.
+ """
+ return dist.is_available() and getattr(dist, "is_initialized", lambda: False)()
+
+
+def setup_dist(temp_dir, rank, world_size):
+ """
+ Set up a distributed process group.
+ """
+
+ if is_initialized():
+ return
+
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+ if os.name == 'nt':
+ init_method = 'file:///' + init_file.replace('\\', '/')
+ dist.init_process_group(
+ backend='gloo', init_method=init_method, rank=rank, world_size=world_size)
+ else:
+ init_method = f'file://{init_file}'
+ dist.init_process_group(
+ backend='nccl', init_method=init_method, rank=rank, world_size=world_size)
+
+ global RANK, WORLD_SIZE
+ RANK = rank
+ WORLD_SIZE = world_size
+
+ torch.cuda.set_device(dev())
+ torch.cuda.empty_cache()
+
+
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# General Tools #
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+
+@functools.lru_cache(maxsize=None)
+def get_rank(group=None):
+ if group is not None and is_initialized():
+ return dist.get_rank(group=group)
+ return RANK
+
+
+@functools.lru_cache(maxsize=None)
+def get_world_size(group=None):
+ if group is not None and is_initialized():
+ return dist.get_world_size(group=group)
+ return WORLD_SIZE
+
+
+def barrier(*args, **kwargs):
+ if is_initialized():
+ return dist.barrier(*args, **kwargs)
+
+
+@contextlib.contextmanager
+def synchronized_ops():
+ barrier()
+ yield
+ barrier()
+ return
+
+
+@functools.lru_cache(maxsize=None)
+def dev(group=None):
+ """
+ Get the device to use for torch.distributed.
+ """
+ if _cuda_available():
+ return torch.device(get_rank(group))
+ return torch.device("cpu")
+
+
+def load_state_dict(local_or_remote_path, **kwargs):
+ """
+ Load a PyTorch file.
+ """
+ with open(local_or_remote_path, "rb") as f:
+ return torch.load(f, **kwargs)
+
+
+def broadcast(tensor, src=0, group=None, async_op=False):
+ """
+ Synchronize a Tensor across ranks from {src} rank. (default=0)
+ :param tensor: torch.Tensor.
+ :param src: source rank to sync params from. default is 0.
+ :param group:
+ :param async_op:
+ """
+ if not is_initialized():
+ return
+ with torch.no_grad():
+ dist.broadcast(tensor, src, group=group, async_op=async_op)
+
+
+def sync_params(params, src=0, group=None, async_op=False):
+ """
+ Synchronize a sequence of Tensors across ranks from {src} rank. (default=0)
+ :param params: Sequence of torch.Tensor.
+ :param src: source rank to sync params from. default is 0.
+ :param group:
+ :param async_op:
+ """
+ if not is_initialized():
+ return
+ for p in params:
+ broadcast(p, src, group=group, async_op=async_op)
diff --git a/GET3D_NADA/experiments/car_police_example.yaml b/GET3D_NADA/experiments/car_police_example.yaml
new file mode 100644
index 0000000..78f47b5
--- /dev/null
+++ b/GET3D_NADA/experiments/car_police_example.yaml
@@ -0,0 +1,69 @@
+# Global config
+GLOBAL:
+ outdir: 'results/car_police_example' # directory path to save ckpt & results
+ resume_pretrain: '../pretrained_model/shapenet_car.pt' # path to pretrained GET3D model
+ batch: 3 # same as # of samples
+ gpus: 2 # num of gpu
+ sample_1st: 500
+ sample_2nd: -1
+ iter_1st: 1
+ iter_2nd: 3
+ output_interval: 1
+ save_interval: 1
+ vis_samples: 9 # num. of images to be visualized
+
+
+# GET3D config
+GET3D:
+ # (1) basic
+ cfg: 'stylegan2'
+ gamma: 40
+ img_res: 1024
+ data_camera_mode: 'shapenet_car'
+
+ # (2) 3D generator : note that these are used at 'G_kwargs'
+ use_style_mixing: True # Use style mixing for generation during inference
+ one_3d_generator: True # Use improved get3d version
+ dmtet_scale: 1.0 # Scale for the dimension of dmtet
+ n_implicit_layer: 1 # Number of Implicit FC layer for XYZPlaneTex model
+ feat_channel: 16 # Feature channel for TORGB layer
+ mlp_latent_channel: 32 # mlp_latent_channel for XYZPlaneTex network
+ deformation_multiplier: 1 # Multiplier for the predicted deformation
+ tri_plane_resolution: 256 # The resolution for tri plane
+ n_views: 8 # GET3D setting : ! DO NOT FIX !
+ use_tri_plane: True # To use tri plane representation
+ tet_res: 90 # Resolution for tetrahedron
+ latent_dim: 512 # Dimension for latent code
+ geometry_type: 'conv3d' # The type of geometry generator
+ render_type: 'neural_render' # Type of renderer we used
+
+ # (3) Misc
+ cbase: 32768 # Capacity multiplier
+ cmax: 512 # Max. feature maps
+ glr: 0.002 # G learning rate
+ mbstd-group: 4 # minibatch std group size
+ c_dim: 0 # class condition
+ img_channels: 3 # RGB image : ! DO NOT FIX !
+
+
+# NADA config
+NADA:
+ # (1) basic
+ lr: 0.002
+ lambda_direction: 1 # strength of directional CLIP loss
+ clip_models: ['ViT-B/32', 'ViT-B/16'] # CLIP image encoder
+ clip_models_weight: [1.0, 1.0] # weight for CLIP image encoder
+
+ # (2) optional
+ lambda_patch: 0.0
+ lambda_global: 0.0
+ lambda_texture: 0.0
+ lambda_manifold: 0.0
+
+ # (3) settings - text / layer-freezing
+ source_text: 'car'
+ target_text: 'police'
+ auto_layer_iters: 1
+ auto_layer_k: 20
+ auto_layer_batch: 12
+ gradient_clip_threshold: -1
diff --git a/GET3D_NADA/experiments/car_sketch_example.yaml b/GET3D_NADA/experiments/car_sketch_example.yaml
new file mode 100644
index 0000000..18723ff
--- /dev/null
+++ b/GET3D_NADA/experiments/car_sketch_example.yaml
@@ -0,0 +1,70 @@
+# Global config
+GLOBAL:
+ outdir: 'results/car_sketch_example' # directory path to save ckpt & results
+ resume_pretrain: '../pretrained_model/shapenet_car.pt' # path to pretrained GET3D model
+ batch: 3 # same as # of samples
+ gpus: 2 # num of gpu
+ sample_1st: 500
+ sample_2nd: -1
+ iter_1st: 1
+ iter_2nd: 30
+ output_interval: 1
+ save_interval: 1
+ vis_samples: 16 # num. of images to be visualized
+
+
+# GET3D config
+GET3D:
+ # (1) basic
+ cfg: 'stylegan2'
+ gamma: 40
+ img_res: 1024
+ data_camera_mode: 'shapenet_car'
+
+ # (2) 3D generator : note that these are used at 'G_kwargs'
+ use_style_mixing: True # Use style mixing for generation during inference
+ one_3d_generator: True # Use improved get3d version
+ dmtet_scale: 1.0 # Scale for the dimension of dmtet
+ n_implicit_layer: 1 # Number of Implicit FC layer for XYZPlaneTex model
+ feat_channel: 16 # Feature channel for TORGB layer
+ mlp_latent_channel: 32 # mlp_latent_channel for XYZPlaneTex network
+ deformation_multiplier: 1 # Multiplier for the predicted deformation
+ tri_plane_resolution: 256 # The resolution for tri plane
+ n_views: 9 # GET3D setting : ! DO NOT FIX !
+ use_tri_plane: True # To use tri plane representation
+ tet_res: 90 # Resolution for tetrahedron
+ latent_dim: 512 # Dimension for latent code
+ geometry_type: 'conv3d' # The type of geometry generator
+ render_type: 'neural_render' # Type of renderer we used
+
+ # (3) Misc
+ cbase: 32768 # Capacity multiplier
+ cmax: 512 # Max. feature maps
+ glr: 0.002 # G learning rate
+ mbstd-group: 4 # minibatch std group size
+ c_dim: 0 # class condition
+ img_channels: 3 # RGB image : ! DO NOT FIX !
+
+
+# NADA config
+NADA:
+ # (1) basic
+ lr: 0.002
+ lambda_direction: 1 # strength of directional CLIP loss
+ clip_models: ['ViT-B/32', 'ViT-B/16'] # CLIP image encoder
+ clip_models_weight: [1.0, 1.0] # weight for CLIP image encoder
+
+ # (2) optional
+ lambda_patch: 0.0
+ lambda_global: 0.0
+ lambda_texture: 0.0
+ lambda_manifold: 0.0
+
+ # (3) settings - text / layer-freezing
+ source_text: 'car'
+ target_text: 'sketch'
+ auto_layer_iters: 1
+ auto_layer_k: 30
+ auto_layer_batch: 12
+ gradient_clip_threshold: -1
+
diff --git a/GET3D_NADA/experiments/motorbike_tiger_example.yaml b/GET3D_NADA/experiments/motorbike_tiger_example.yaml
new file mode 100644
index 0000000..402f519
--- /dev/null
+++ b/GET3D_NADA/experiments/motorbike_tiger_example.yaml
@@ -0,0 +1,69 @@
+# Global config
+GLOBAL:
+ outdir: 'results/motorbike_tiger_example' # directory path to save ckpt & results
+ resume_pretrain: '../pretrained_model/shapenet_motorbike.pt' # path to pretrained GET3D model
+ batch: 3 # same as # of samples
+ gpus: 2 # num of gpu
+ sample_1st: 500
+ sample_2nd: -1
+ iter_1st: 1
+ iter_2nd: 3
+ output_interval: 1
+ save_interval: 1
+ vis_samples: 9 # num. of images to be visualized
+
+
+# GET3D config
+GET3D:
+ # (1) basic
+ cfg: 'stylegan2'
+ gamma: 40
+ img_res: 1024
+ data_camera_mode: 'shapenet_car'
+
+ # (2) 3D generator : note that these are used at 'G_kwargs'
+ use_style_mixing: True # Use style mixing for generation during inference
+ one_3d_generator: True # Use improved get3d version
+ dmtet_scale: 1.0 # Scale for the dimension of dmtet
+ n_implicit_layer: 1 # Number of Implicit FC layer for XYZPlaneTex model
+ feat_channel: 16 # Feature channel for TORGB layer
+ mlp_latent_channel: 32 # mlp_latent_channel for XYZPlaneTex network
+ deformation_multiplier: 1 # Multiplier for the predicted deformation
+ tri_plane_resolution: 256 # The resolution for tri plane
+ n_views: 8 # GET3D setting : ! DO NOT CHANGE !
+ use_tri_plane: True # To use tri plane representation
+ tet_res: 90 # Resolution for tetrahedron
+ latent_dim: 512 # Dimension for latent code
+ geometry_type: 'conv3d' # The type of geometry generator
+ render_type: 'neural_render' # Type of renderer we used
+
+ # (3) Misc
+ cbase: 32768 # Capacity multiplier
+ cmax: 512 # Max. feature maps
+ glr: 0.002 # G learning rate
+ mbstd-group: 4 # minibatch std group size
+ c_dim: 0 # class condition
+ img_channels: 3 # RGB image : ! DO NOT CHANGE !
+
+
+# NADA config
+NADA:
+ # (1) basic
+ lr: 0.002
+ lambda_direction: 1 # strength of directional CLIP loss
+ clip_models: ['ViT-B/32', 'ViT-B/16'] # CLIP image encoder
+ clip_models_weight: [1.0, 1.0] # weight for CLIP image encoder
+
+ # (2) optional
+ lambda_patch: 0.0
+ lambda_global: 0.0
+ lambda_texture: 0.0
+ lambda_manifold: 0.0
+
+ # (3) settings - text / layer-freezing
+ source_text: 'motorbike'
+ target_text: 'tiger'
+ auto_layer_iters: 1
+ auto_layer_k: 20
+ auto_layer_batch: 12
+ gradient_clip_threshold: -1
diff --git a/GET3D_NADA/experiments/shoes_mossy_example.yaml b/GET3D_NADA/experiments/shoes_mossy_example.yaml
new file mode 100644
index 0000000..e4e78bc
--- /dev/null
+++ b/GET3D_NADA/experiments/shoes_mossy_example.yaml
@@ -0,0 +1,68 @@
+# Global config
+GLOBAL:
+ outdir: 'results/shoes_mossy_example' # directory path to save ckpt & results
+ resume_pretrain: '../pretrained_model/shoes-00204.pt' # path to pretrained GET3D model
+ batch: 3 # same as # of samples
+ gpus: 2 # num of gpu
+ sample_1st: 500
+ sample_2nd: -1
+ iter_1st: 1
+ iter_2nd: 3
+ output_interval: 1
+ save_interval: 1
+ vis_samples: 9 # num. of images to be visualized
+
+
+# GET3D config
+GET3D:
+ # (1) basic
+ cfg: 'stylegan2'
+ gamma: 40
+ img_res: 1024
+ data_camera_mode: 'shapenet_car'
+
+ # (2) 3D generator : note that these are used at 'G_kwargs'
+ use_style_mixing: True # Use style mixing for generation during inference
+ one_3d_generator: True # Use improved get3d version
+ dmtet_scale: 1.0 # Scale for the dimension of dmtet
+ n_implicit_layer: 1 # Number of Implicit FC layer for XYZPlaneTex model
+ feat_channel: 16 # Feature channel for TORGB layer
+ mlp_latent_channel: 32 # mlp_latent_channel for XYZPlaneTex network
+ deformation_multiplier: 1 # Multiplier for the predicted deformation
+ tri_plane_resolution: 256 # The resolution for tri plane
+ n_views: 8 # GET3D setting : ! DO NOT CHANGE !
+ use_tri_plane: True # To use tri plane representation
+ tet_res: 90 # Resolution for tetrahedron
+ latent_dim: 512 # Dimension for latent code
+ geometry_type: 'conv3d' # The type of geometry generator
+ render_type: 'neural_render' # Type of renderer we used
+
+ # (3) Misc
+ cbase: 32768 # Capacity multiplier
+ cmax: 512 # Max. feature maps
+ glr: 0.002 # G learning rate
+ mbstd-group: 4 # minibatch std group size
+ c_dim: 0 # class condition
+ img_channels: 3 # RGB image : ! DO NOT CHANGE !
+
+# NADA config
+NADA:
+ # (1) basic
+ lr: 0.002
+ lambda_direction: 1 # strength of directional CLIP loss
+ clip_models: ['ViT-B/32', 'ViT-B/16'] # CLIP image encoder
+ clip_models_weight: [1.0, 1.0] # weight for CLIP image encoder
+
+ # (2) optional
+ lambda_patch: 0.0
+ lambda_global: 0.0
+ lambda_texture: 0.0
+ lambda_manifold: 0.0
+
+ # (3) settings - text / layer-freezing
+ source_text: 'shoes'
+ target_text: 'mossy'
+ auto_layer_iters: 1
+ auto_layer_k: 30
+ auto_layer_batch: 12
+ gradient_clip_threshold: -1
diff --git a/GET3D_NADA/extra_requirements.txt b/GET3D_NADA/extra_requirements.txt
new file mode 100644
index 0000000..c5dcda2
--- /dev/null
+++ b/GET3D_NADA/extra_requirements.txt
@@ -0,0 +1,3 @@
+pyyaml
+ftfy
+git+https://github.com/openai/CLIP.git
diff --git a/GET3D_NADA/functional.py b/GET3D_NADA/functional.py
new file mode 100644
index 0000000..9c242c7
--- /dev/null
+++ b/GET3D_NADA/functional.py
@@ -0,0 +1,338 @@
+"""
+Methods (functions) for GET3D generator, required for NADA training and inference
+"""
+import torch
+import nvdiffrast.torch as dr
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from training.networks_get3d import DMTETSynthesisNetwork
+ from training.networks_get3d import GeneratorDMTETMesh
+
+
+# Class GeneratorDMTETMesh
+def get_all_generator_layers_dict(self: "GeneratorDMTETMesh"):
+
+ layer_idx_geo = {}
+ layer_idx_tex = {}
+
+ tri_plane_blocks = self.synthesis.generator.tri_plane_synthesis.children()
+
+ idx_geo = 0
+ idx_tex = 0
+
+ # triplane
+ for block in tri_plane_blocks:
+ if hasattr(block, 'conv0'):
+ layer_idx_geo[idx_geo] = f'b{block.resolution}.conv0'
+ idx_geo += 1
+ if hasattr(block, 'conv1'):
+ layer_idx_geo[idx_geo] = f'b{block.resolution}.conv1'
+ idx_geo += 1
+ if hasattr(block, 'togeo'):
+ layer_idx_geo[idx_geo] = f'b{block.resolution}.togeo'
+ idx_geo += 1
+ if hasattr(block, 'totex'):
+ layer_idx_tex[idx_tex] = f'b{block.resolution}.totex'
+ idx_tex += 1
+
+ # mlp_synthesis
+ # note that last number = ModuleList index
+ layer_idx_tex[idx_tex] = 'mlp_synthesis_tex.0'
+ idx_tex += 1
+ layer_idx_tex[idx_tex] = 'mlp_synthesis_tex.1'
+
+ layer_idx_geo[idx_geo] = 'mlp_synthesis_geo.0'
+ idx_geo += 1
+ layer_idx_geo[idx_geo] = 'mlp_synthesis_geo.1'
+
+ return layer_idx_tex, layer_idx_geo
+
+
+# Class GeneratorDMTETMesh
+def freeze_generator_layers(self: "GeneratorDMTETMesh", layer_tex_dict=None, layer_geo_dict=None):
+ assert layer_geo_dict is None and layer_tex_dict is None
+ self.synthesis.requires_grad_(False) # all freeze
+
+
+# Class GeneratorDMTETMesh
+def unfreeze_generator_layers(self: "GeneratorDMTETMesh", topk_idx_tex: list, topk_idx_geo: list):
+ """
+ args
+ topk_idx_tex : chosen layer indices - tex
+ topk_idx_geo : chosen layer indices - geo
+ layer_tex_dict , layer_geo_dict : result of get_all_generator_layers_dict()
+ """
+ if not topk_idx_tex and not topk_idx_geo:
+ self.synthesis.generator.tri_plane_synthesis.requires_grad_(True)
+ return # all unfreeze
+
+ layer_tex_dict, layer_geo_dict = get_all_generator_layers_dict(self)
+
+ for idx_tex in topk_idx_tex:
+ if idx_tex >= 7:
+ # mlp_synthesis_tex
+ mlp_name, layer_idx = layer_tex_dict[idx_tex].split('.')
+ layer_tex = getattr(self.synthesis.generator.mlp_synthesis_tex, 'layers')[int(layer_idx)]
+ layer_tex.requires_grad_(True)
+ self.synthesis.generator.mlp_synthesis_tex.layers[int(layer_idx)] = layer_tex
+
+ else:
+ # Texture TriPlane
+ block_name, layer_name = layer_tex_dict[idx_tex].split('.')
+ block = getattr(self.synthesis.generator.tri_plane_synthesis, block_name)
+ getattr(block, layer_name).requires_grad_(True)
+ setattr(self.synthesis.generator.tri_plane_synthesis, block_name, block)
+
+ for idx_geo in topk_idx_geo:
+ if idx_geo >= 20:
+ # mlp_synthesis_sdf
+ mlp_name, layer_idx = layer_geo_dict[idx_geo].split('.')
+ layer_sdf = getattr(self.synthesis.generator.mlp_synthesis_sdf, 'layers')[int(layer_idx)]
+ layer_sdf.requires_grad_(True)
+ self.synthesis.generator.mlp_synthesis_sdf.layers[int(layer_idx)] = layer_sdf
+ # mlp_synthesis_def
+ layer_def = getattr(self.synthesis.generator.mlp_synthesis_def, 'layers')[int(layer_idx)]
+ layer_def.requires_grad_(True)
+ self.synthesis.generator.mlp_synthesis_def.layers[int(layer_idx)] = layer_def
+
+ else:
+ # Geometry TriPlane
+ block_name, layer_name = layer_geo_dict[idx_geo].split('.')
+ block = getattr(self.synthesis.generator.tri_plane_synthesis, block_name)
+ getattr(block, layer_name).requires_grad_(True)
+ setattr(self.synthesis.generator.tri_plane_synthesis, block_name, block)
+
+
+# Class :DMTETSynthesisNetwork
+def generate_nada_mode_synthesis(
+ self: "DMTETSynthesisNetwork",
+ ws,
+ ws_geo,
+ camera=None,
+ texture_resolution=2048,
+ mode='nada',
+ **block_kwargs
+):
+ """
+ mode='thumbnail' : To make thumbnail
+ mode='layer' : To support layer-freezing
+ mode='nada' : To support 1 latent - N views rendering
+ """
+
+ # ------------------- generate ------------------- #
+
+ # (1) Generate 3D mesh first
+ # NOTE :
+ # this code is shared by 'def generate' and 'def extract_3d_mesh'
+ if self.one_3d_generator:
+ sdf_feature, tex_feature = self.generator.get_feature(
+ ws[:, :self.generator.tri_plane_synthesis.num_ws_tex],
+ ws_geo[:, :self.generator.tri_plane_synthesis.num_ws_geo])
+ ws = ws[:, self.generator.tri_plane_synthesis.num_ws_tex:]
+ ws_geo = ws_geo[:, self.generator.tri_plane_synthesis.num_ws_geo:]
+ mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(ws_geo, sdf_feature)
+ else:
+ mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(ws_geo)
+
+ ws_tex = ws
+
+ # (2) Generate random camera
+ with torch.no_grad():
+ if camera is None:
+ if mode == 'nada':
+ campos, cam_mv, rotation_angle, elevation_angle, sample_r = self.generate_random_camera(
+ ws_tex.shape[0], n_views=self.n_views)
+ gen_camera = (campos, cam_mv, sample_r, rotation_angle, elevation_angle)
+ run_n_view = self.n_views
+ else:
+ campos, cam_mv, rotation_angle, elevation_angle, sample_r = self.generate_random_camera(
+ ws_tex.shape[0], n_views=1)
+ gen_camera = (campos, cam_mv, sample_r, rotation_angle, elevation_angle)
+ run_n_view = 1
+ else:
+ if isinstance(camera, tuple):
+ cam_mv = camera[0]
+ campos = camera[1]
+ else:
+ cam_mv = camera
+ campos = None
+ gen_camera = camera
+ run_n_view = cam_mv.shape[1]
+
+ # NOTE
+ # tex_pos: Position we want to query the texture field || List[(1,1024, 1024,3) * Batch]
+ # tex_hard_mask : 2D silhouette of the rendered image || Tensor(Batch, 1024, 1024, 1)
+
+ if mode == 'nada':
+
+ antilias_mask = []
+ tex_pos = []
+ tex_hard_mask = []
+ return_value = {'tex_pos': []}
+
+ for idx in range(self.n_views):
+ cam = cam_mv[:, idx, :, :].unsqueeze(1)
+ antilias_mask_, hard_mask_, return_value_ = self.render_mesh(mesh_v, mesh_f, cam)
+ antilias_mask.append(antilias_mask_)
+ tex_hard_mask.append(hard_mask_)
+
+ for pos in return_value_['tex_pos']:
+ return_value['tex_pos'].append(pos)
+
+ antilias_mask = torch.cat(antilias_mask, dim=0) # (B*n_view, 1024, 1024, 1)
+ tex_hard_mask = torch.cat(tex_hard_mask, dim=0) # (B*n_view, 1024, 1024, 3)
+ tex_pos = return_value['tex_pos']
+
+ ws_tex = ws_tex.repeat(self.n_views, 1, 1)
+ ws_geo = ws_geo.repeat(self.n_views, 1, 1)
+ tex_feature = tex_feature.repeat(self.n_views, 1, 1, 1)
+
+ else:
+ # (3) Render the mesh into 2D image (get 3d position of each image plane)
+ antilias_mask, hard_mask, return_value = self.render_mesh(mesh_v, mesh_f, cam_mv)
+
+ tex_pos = return_value['tex_pos']
+ tex_hard_mask = hard_mask
+
+ tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos]
+ tex_hard_mask = torch.cat(
+ [torch.cat(
+ [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1]
+ for i_view in range(run_n_view)], dim=2)
+ for i in range(ws_tex.shape[0])], dim=0)
+
+ # (4) Querying the texture field to predict the texture feature for each pixel on the image
+ if self.one_3d_generator:
+ tex_feat = self.get_texture_prediction(ws_tex, tex_pos, ws_geo.detach(), tex_hard_mask, tex_feature)
+ else:
+ tex_feat = self.get_texture_prediction(
+ ws_tex, tex_pos, ws_geo.detach(), tex_hard_mask)
+ background_feature = torch.zeros_like(tex_feat)
+
+ # (5) Merge them together
+ img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask)
+
+ # (6) Splitting the image back to its original shape is not needed here,
+ # so that step is skipped
+
+ ws_list = [ws_tex[i].unsqueeze(dim=0).expand(return_value['tex_pos'][i].shape[0], -1, -1) for i in
+ range(len(return_value['tex_pos']))]
+ ws = torch.cat(ws_list, dim=0).contiguous()
+
+ # (7) Predict the RGB color for each pixel (self.to_rgb is 1x1 convolution)
+ if self.feat_channel > 3:
+ network_out = self.to_rgb(img_feat.permute(0, 3, 1, 2), ws[:, -1])
+ else:
+ network_out = img_feat.permute(0, 3, 1, 2)
+
+ img = network_out
+ img_buffers_viz = None
+
+ if self.render_type == 'neural_render':
+ img = img[:, :3]
+ else:
+ raise NotImplementedError
+
+ img = torch.cat([img, antilias_mask.permute(0, 3, 1, 2)], dim=1)
+
+ return_generate = [img, antilias_mask]
+
+ if mode == 'layer' or mode == 'nada':
+ return return_generate[0], None
+
+ elif mode == 'thumbnail':
+ # ------------------- extract_3d_shape ------------------- #
+
+ del tex_hard_mask
+ del tex_feat
+
+ # (8) Use x-atlas to get uv mapping for the mesh
+ from training.extract_texture_map import xatlas_uvmap
+ all_uvs = []
+ all_mesh_tex_idx = []
+ all_gb_pose = []
+ all_uv_mask = []
+ if self.dmtet_geometry.renderer.ctx is None:
+ self.dmtet_geometry.renderer.ctx = dr.RasterizeGLContext(device=self.device)
+ for v, f in zip(mesh_v, mesh_f):
+ uvs, mesh_tex_idx, gb_pos, mask = xatlas_uvmap(
+ self.dmtet_geometry.renderer.ctx, v, f, resolution=texture_resolution)
+ all_uvs.append(uvs)
+ all_mesh_tex_idx.append(mesh_tex_idx)
+ all_gb_pose.append(gb_pos)
+ all_uv_mask.append(mask)
+
+ tex_hard_mask = torch.cat(all_uv_mask, dim=0).float()
+
+ # (9) Query the texture field to get the RGB color for texture map
+ all_network_output = []
+ for _ws, _all_gb_pose, _ws_geo, _tex_hard_mask in zip(ws, all_gb_pose, ws_geo, tex_hard_mask):
+ if self.one_3d_generator:
+ tex_feat = self.get_texture_prediction(
+ _ws.unsqueeze(dim=0), [_all_gb_pose],
+ _ws_geo.unsqueeze(dim=0).detach(),
+ _tex_hard_mask.unsqueeze(dim=0),
+ tex_feature)
+ else:
+ tex_feat = self.get_texture_prediction(
+ _ws.unsqueeze(dim=0), [_all_gb_pose],
+ _ws_geo.unsqueeze(dim=0).detach(),
+ _tex_hard_mask.unsqueeze(dim=0))
+ background_feature = torch.zeros_like(tex_feat)
+ # Merge them together
+ img_feat = tex_feat * _tex_hard_mask.unsqueeze(dim=0) + background_feature * (
+ 1 - _tex_hard_mask.unsqueeze(dim=0))
+ network_out = self.to_rgb(img_feat.permute(0, 3, 1, 2), _ws.unsqueeze(dim=0)[:, -1])
+ all_network_output.append(network_out)
+ network_out = torch.cat(all_network_output, dim=0)
+
+ return_extract_3d_mesh = [mesh_v, mesh_f, all_uvs, all_mesh_tex_idx, network_out]
+
+ return return_generate, return_extract_3d_mesh
+
+ else:
+ raise NotImplementedError
+
+
+# Class: GeneratorDMTETMesh
+def generate_nada_mode(
+ self: "GeneratorDMTETMesh",
+ geo_z, tex_z, c=0, truncation_psi=1, truncation_cutoff=None,
+ update_emas=False, use_mapping=False, # -> generate_3d_mesh
+ camera=None, # -> generate_3d
+ mode='thumbnail',
+ **synthesis_kwargs):
+ """
+ Description
+ mode='thumbnail' : To make thumbnail
+ mode='layer' : To support layer-freezing
+ mode='nada' : To support 1 latent - N views rendering
+
+ Note :
+ this function does not take the following as input args,
+ since they are redundant:
+ 1. use_style_mixing
+ 2. generate_no_light
+ 3. with_texture
+
+ Return :
+ return_generate_3d = [rendered RGB Image, rendered 2D Silhouette image]
+ return_generate_3d_mesh = [mesh_v, mesh_f, all_uvs, all_mesh_tex_idx, texture map]
+ """
+
+ if use_mapping or mode == 'thumbnail':
+ ws = self.mapping(
+ tex_z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
+ update_emas=update_emas)
+ ws_geo = self.mapping_geo(
+ geo_z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
+ update_emas=update_emas)
+ else:
+ ws = tex_z
+ ws_geo = geo_z
+
+ return generate_nada_mode_synthesis(
+ self.synthesis,
+ ws, ws_geo, camera=camera, mode=mode, **synthesis_kwargs
+ ) # custom inference code.
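The freeze/unfreeze helpers above address layers through dotted names such as `'b256.totex'` (tri-plane block + attribute) or `'mlp_synthesis_tex.0'` (MLP + `ModuleList` index), with texture indices `>= 7` falling into the MLP branch. A framework-free sketch of that bookkeeping (the dict below is a toy stand-in for `get_all_generator_layers_dict`'s output, not the real network):

```python
# Toy layer-index dict mimicking get_all_generator_layers_dict's texture output:
# seven tri-plane 'totex' layers, then two mlp_synthesis_tex layers.
layer_idx_tex = {
    0: 'b4.totex', 1: 'b8.totex', 2: 'b16.totex', 3: 'b32.totex',
    4: 'b64.totex', 5: 'b128.totex', 6: 'b256.totex',
    7: 'mlp_synthesis_tex.0', 8: 'mlp_synthesis_tex.1',
}

def resolve(name):
    """Split a dotted layer name into (owner, key) the way unfreeze does."""
    owner, key = name.split('.')
    if owner.startswith('mlp_synthesis'):
        return owner, int(key)   # ModuleList index
    return owner, key            # attribute name on a tri-plane block

print(resolve(layer_idx_tex[6]))  # ('b256', 'totex')
print(resolve(layer_idx_tex[7]))  # ('mlp_synthesis_tex', 0)
```

The `idx_tex >= 7` threshold in `unfreeze_generator_layers` corresponds to the point where the dict switches from tri-plane entries to MLP entries.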
diff --git a/GET3D_NADA/model_engine.py b/GET3D_NADA/model_engine.py
new file mode 100644
index 0000000..1ab5694
--- /dev/null
+++ b/GET3D_NADA/model_engine.py
@@ -0,0 +1,171 @@
+"""
+Config parser class for efficient config management
+"""
+import os
+import sys
+import copy
+import yaml
+import torch
+from contextlib import contextmanager
+
+import dist_util
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from training.networks_get3d import GeneratorDMTETMesh
+
+GET3D_ROOT = None
+
+
+class Engine(object):
+ """Config parser class for efficient management"""
+ rank: int
+ config: dict
+ device: torch.device
+ global_kwargs: dict
+ G_kwargs: dict
+ clip_kwargs: dict
+
+ @classmethod
+ def parse_engine_like(cls, engine_like):
+ if isinstance(engine_like, cls): # Engine
+ return engine_like
+ elif isinstance(engine_like, dict): # config dict
+ return cls(engine_like)
+ elif isinstance(engine_like, str) or hasattr(engine_like, '__fspath__'): # path
+ with open(engine_like, 'r') as fp:
+ return cls(yaml.safe_load(fp))
+ elif hasattr(engine_like, 'read'): # file-like
+ return cls(yaml.safe_load(engine_like))
+ raise TypeError
+
+ def __init__(self, config: dict, rank: "int|None" = None):
+ self.rank = rank
+ self.config = config
+ self.parse()
+
+ def parse(self):
+ if self.rank is None:
+ self.rank = dist_util.get_rank()
+ self.device = torch.device('cuda', self.rank)
+
+ # setting : global configuration
+ self.global_kwargs = dnnlib.EasyDict(self.config['GLOBAL'])
+
+ # ref) get3d : train_3d.py ln251 - ln320
+ # setting : GET3D configuration
+ opts = dnnlib.EasyDict(self.config['GET3D'])
+ # global
+ G_kwargs = self.G_kwargs = dnnlib.EasyDict()
+ G_kwargs.device = self.device
+ G_kwargs.class_name = 'training.networks_get3d.GeneratorDMTETMesh'
+ G_kwargs.img_resolution = opts.img_res # // reformed
+ G_kwargs.img_channels = opts.img_channels # // reformed
+ # mapping network
+ G_kwargs.z_dim = opts.latent_dim
+ G_kwargs.w_dim = opts.latent_dim
+ G_kwargs.c_dim = opts.c_dim # 0(=None) # NOTE : This can be used for class conditioning ... // reformed
+ G_kwargs.mapping_kwargs = dnnlib.EasyDict()
+ G_kwargs.mapping_kwargs.num_layers = 8
+ # stylegan2 + tri-plane
+ G_kwargs.use_style_mixing = opts.use_style_mixing
+ G_kwargs.one_3d_generator = opts.one_3d_generator
+ G_kwargs.dmtet_scale = opts.dmtet_scale
+ G_kwargs.n_implicit_layer = opts.n_implicit_layer
+ G_kwargs.feat_channel = opts.feat_channel
+ G_kwargs.mlp_latent_channel = opts.mlp_latent_channel
+ G_kwargs.deformation_multiplier = opts.deformation_multiplier
+ G_kwargs.tri_plane_resolution = opts.tri_plane_resolution
+ G_kwargs.n_views = opts.n_views
+ G_kwargs.use_tri_plane = opts.use_tri_plane
+ G_kwargs.tet_res = opts.tet_res
+ # G_kwargs.tet_path = '../data/tets'
+ # neural renderer
+ G_kwargs.render_type = opts.render_type
+ G_kwargs.data_camera_mode = opts.data_camera_mode
+ # misc
+ G_kwargs.fused_modconv_default = 'inference_only'
+
+ # setting : NADA configuration
+ clip_kwargs = self.clip_kwargs = dnnlib.EasyDict(self.config['NADA'])
+ clip_kwargs.device = self.device
+
+ def build_get3d_pair(self):
+ with at_working_directory(GET3D_ROOT):
+ G_ema: "GeneratorDMTETMesh" = dnnlib.util.construct_class_by_name(**self.G_kwargs)
+ G_ema.to(self.device).train().requires_grad_(False)
+
+ assert self.global_kwargs['resume_pretrain'] != '', "ASSERTION : Specify pretrained GET3D model"
+ if self.rank == 0:
+ model_state_dict = torch.load(
+ self.global_kwargs['resume_pretrain'],
+ map_location=self.device
+ )
+ G_ema.load_state_dict(model_state_dict['G_ema'], strict=True)
+ dist_util.sync_params(G_ema.parameters(), src=0)
+ dist_util.sync_params(G_ema.buffers(), src=0)
+
+ G_ema_frozen: "GeneratorDMTETMesh" = copy.deepcopy(G_ema).eval()
+ return G_ema, G_ema_frozen
+
+
+@contextmanager
+def at_working_directory(work_dir):
+ """Context manager for changing working directory."""
+ prev = os.getcwd()
+ try:
+ os.chdir(work_dir)
+ yield
+ finally:
+ os.chdir(prev)
+
+
+def find_get3d():
+ """
+ This function makes dynamic import of GET3D modules available.
+ Officially supported ways:
+ 1. Place this module's directory inside the GET3D directory. (recommended)
+ 2. Locate GET3D via submodule, by `git submodule sync && git submodule update --init --recursive`.
+ 3. Set GET3D directory via environment variable `GET3D_ROOT`.
+ 4. Manually specify GET3D directory in this file, by variable `GET3D_ROOT` (line 21).
+ """
+ global GET3D_ROOT
+ # 1. check if GET3D_ROOT is already set and in sys.path
+ if GET3D_ROOT is not None and GET3D_ROOT in sys.path:
+ return True
+ # 2. check if GET3D modules are already imported and __file__ attribute is available
+ try:
+ import training.networks_get3d
+ except ImportError:
+ pass
+ if hasattr(sys.modules.get('training.networks_get3d', None), '__file__'):
+ GET3D_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(sys.modules['training.networks_get3d'].__file__)))
+ return True
+ # 3. check if GET3D_ROOT is specified via environment variable, or try to guess
+ import importlib
+ base = os.path.dirname(os.path.abspath(__file__))
+ candidates = [
+ GET3D_ROOT,
+ os.getenv('GET3D_ROOT', None),
+ os.path.dirname(base),
+ os.path.join(base, 'GET3D'),
+ ]
+ for candidate in candidates: # Try each candidate path in order.
+ if candidate is not None and os.path.isdir(os.path.join(candidate, 'training')):
+ try:
+ sys.path.insert(0, candidate)
+ importlib.import_module('training.networks_get3d')
+ GET3D_ROOT = candidate
+ break
+ except ImportError:
+ sys.path.pop(0)
+ if GET3D_ROOT is None: # Fail if all candidates failed.
+ raise ImportError(
+ 'Failed to find GET3D root directory. '
+ 'Please specify the location of GET3D via GET3D_ROOT environment variable.'
+ )
+ else:
+ return True
+
+
+if find_get3d():
+ import dnnlib
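The `at_working_directory` context manager in `model_engine.py` is stdlib-only and can be exercised in isolation. This standalone copy (a sketch duplicating the implementation above so it runs outside the repo) shows the restore-on-exit behavior that `build_get3d_pair` relies on:

```python
import os
import tempfile
from contextlib import contextmanager

@contextmanager
def at_working_directory(work_dir):
    """Temporarily chdir into work_dir, restoring the old cwd even on error."""
    prev = os.getcwd()
    try:
        os.chdir(work_dir)
        yield
    finally:
        os.chdir(prev)

before = os.getcwd()
with tempfile.TemporaryDirectory() as tmp:
    with at_working_directory(tmp):
        # Inside the block the process cwd is the temporary directory.
        assert os.path.samefile(os.getcwd(), tmp)
assert os.getcwd() == before  # cwd restored after the block
```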
diff --git a/GET3D_NADA/nada.py b/GET3D_NADA/nada.py
new file mode 100644
index 0000000..f7ede3c
--- /dev/null
+++ b/GET3D_NADA/nada.py
@@ -0,0 +1,151 @@
+"""
+Component Description
+ - YAIverseGAN (NADA) module
+"""
+
+import torch
+from clip_loss import CLIPLoss
+from functional import generate_nada_mode, freeze_generator_layers, unfreeze_generator_layers
+from model_engine import Engine
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from typing import Union, Dict, Any
+
+
+class GeneratorGET3DNADA(torch.nn.Module):
+
+ def __init__(self, engine: "Union[Engine, Dict[str, Any], str]"):
+ super(GeneratorGET3DNADA, self).__init__()
+
+ self.engine = Engine.parse_engine_like(engine)
+ self.device = self.engine.device
+
+ self.generator_trainable, self.generator_frozen = self.engine.build_get3d_pair()
+
+ # clip model & loss
+ clip_kwargs = self.engine.clip_kwargs
+ self.clip_loss_models = torch.nn.ModuleDict({
+ model_name: CLIPLoss(
+ self.device,
+ lambda_direction=clip_kwargs['lambda_direction'],
+ lambda_patch=clip_kwargs['lambda_patch'],
+ lambda_global=clip_kwargs['lambda_global'],
+ lambda_texture=clip_kwargs['lambda_texture'],
+ lambda_manifold=clip_kwargs['lambda_manifold'],
+ clip_model=model_name
+ ) for model_name in clip_kwargs['clip_models']
+ })
+
+ self.clip_model_weights = {
+ model_name: weight for model_name, weight in
+ zip(clip_kwargs['clip_models'], clip_kwargs['clip_models_weight'])
+ }
+
+ # text
+ self.source_text = clip_kwargs['source_text']
+ self.target_text = clip_kwargs['target_text']
+
+ # for layer freezing process
+ self.auto_layer_k = clip_kwargs['auto_layer_k']
+ self.auto_layer_iters = clip_kwargs['auto_layer_iters']
+ self.auto_layer_batch = clip_kwargs['auto_layer_batch']
+
+ self.to(self.device)
+
+ def get_loop_settings(self):
+ g = self.engine.global_kwargs
+ return (
+ self.device, g.outdir, g.batch, g.vis_samples,
+ g.sample_1st, g.sample_2nd, g.iter_1st, g.iter_2nd,
+ self.engine.clip_kwargs.lr,
+ g.output_interval, g.save_interval,
+ self.engine.clip_kwargs.gradient_clip_threshold
+ )
+
+ def determine_opt_layers(self):
+ """
+ original code : returns chosen layers : List[nn.Module, nn.Module, ...]
+ this code : returns chosen layer indices : List[int, int, ...], List[int, int, ...]
+ * note that this returns two lists, one for tex. and one for geo.
+ """
+ z_dim = 512
+ c_dim = self.engine.G_kwargs['c_dim']
+ sample_z_tex = torch.randn(self.auto_layer_batch, z_dim, device=self.device)
+ sample_z_geo = torch.randn(self.auto_layer_batch, z_dim, device=self.device)
+
+ with torch.no_grad():
+ initial_w_tex_codes = self.generator_frozen.mapping(sample_z_tex, c_dim) # (B, 9, 512)
+ initial_w_geo_codes = self.generator_frozen.mapping_geo(sample_z_geo, c_dim) # (B, 22, 512)
+
+ w_tex_codes = initial_w_tex_codes.clone().detach().to(self.device)
+ w_geo_codes = initial_w_geo_codes.clone().detach().to(self.device)
+
+ w_tex_codes.requires_grad = True
+ w_geo_codes.requires_grad = True
+
+ w_optim = torch.optim.SGD([w_tex_codes, w_geo_codes], lr=0.01)
+
+ for _ in range(self.auto_layer_iters):
+ generated_from_w, _ = generate_nada_mode(self.generator_trainable, tex_z=w_tex_codes, geo_z=w_geo_codes,
+ mode='layer') # (B, C, H, W)
+ generated_from_w = generated_from_w[:, :-1, :, :] # [RGB image, Silhouette] (B,4,H,W) -> [RGB image] (B,3,H,W)
+ w_loss = [self.clip_model_weights[model_name] * self.clip_loss_models[model_name].global_clip_loss(
+ generated_from_w, self.target_text) for model_name in self.clip_model_weights.keys()]
+ w_loss = torch.sum(torch.stack(w_loss))
+
+ w_optim.zero_grad()
+ w_loss.backward()
+ w_optim.step()
+
+ layer_tex_weights = torch.abs(w_tex_codes - initial_w_tex_codes).mean(dim=-1).mean(dim=0)
+ layer_geo_weights = torch.abs(w_geo_codes - initial_w_geo_codes).mean(dim=-1).mean(dim=0)
+
+ cutoff = len(layer_tex_weights)
+
+ chosen_layers_idx = torch.topk(torch.cat([layer_tex_weights, layer_geo_weights], dim=0), self.auto_layer_k)[
+ 1].cpu().numpy().tolist()
+ chosen_layer_idx_tex = []
+ chosen_layer_idx_geo = []
+ for idx in chosen_layers_idx:
+ if idx >= cutoff:
+ chosen_layer_idx_geo.append(idx - cutoff)
+ else:
+ chosen_layer_idx_tex.append(idx)
+
+ return chosen_layer_idx_tex, chosen_layer_idx_geo
+
+ def forward(self, tex_z, geo_z): # modified for GET3D
+ c_dim = self.engine.G_kwargs['c_dim']
+ batch = tex_z.shape[0]
+
+ if self.training and self.auto_layer_iters > 0:
+ unfreeze_generator_layers(self.generator_trainable, [], [])
+ topk_idx_tex, topk_idx_geo = self.determine_opt_layers()
+ freeze_generator_layers(self.generator_trainable)
+ unfreeze_generator_layers(self.generator_trainable, topk_idx_tex, topk_idx_geo)
+
+ w_geo = self.generator_frozen.mapping_geo(geo_z, c_dim)
+ w_tex = self.generator_frozen.mapping(tex_z, c_dim)
+
+ with torch.no_grad():
+ frozen_img, _ = generate_nada_mode(self.generator_frozen, tex_z=w_tex, geo_z=w_geo, c=c_dim, mode='nada')
+
+ trainable_img, _ = generate_nada_mode(self.generator_trainable, tex_z=w_tex, geo_z=w_geo, c=c_dim, mode='nada')
+
+ input_dict = {
+ "src_img": frozen_img[:, :-1],
+ "target_img": trainable_img[:, :-1],
+ "source_class": self.source_text,
+ "target_class": self.target_text
+ }
+
+ clip_loss = torch.sum(
+ torch.stack([
+ self.clip_model_weights[model_name] * self.clip_loss_models[model_name](**input_dict)
+ for model_name in self.clip_model_weights.keys()
+ ]), dim=0
+ )
+
+ clip_loss = torch.mean(clip_loss.reshape(-1, batch), dim=0)
+ return [frozen_img, trainable_img], clip_loss
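`determine_opt_layers` ranks all W+ layers by how far a few CLIP-guided steps moved them, takes the top-k indices over the concatenated texture+geometry weights, and splits them back into two lists using the texture length as cutoff. A minimal, tensor-free sketch of that split (the weight values below are made up for illustration):

```python
def split_topk(tex_weights, geo_weights, k):
    """Pick the k largest movements over tex+geo, split indices at the cutoff."""
    cutoff = len(tex_weights)
    combined = tex_weights + geo_weights
    topk = sorted(range(len(combined)), key=lambda i: combined[i], reverse=True)[:k]
    tex_idx = [i for i in topk if i < cutoff]
    geo_idx = [i - cutoff for i in topk if i >= cutoff]
    return tex_idx, geo_idx

# Made-up per-layer movement magnitudes (9 texture, 4 geometry layers).
tex = [0.1, 0.9, 0.2, 0.05, 0.3, 0.01, 0.4, 0.02, 0.03]
geo = [0.8, 0.7, 0.06, 0.5]
tex_idx, geo_idx = split_topk(tex, geo, k=4)
print(tex_idx, geo_idx)  # [1] [0, 1, 3]
```

In the real code the ranking is done with `torch.topk` on the concatenated mean absolute W+ deltas; the index arithmetic is the same.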
diff --git a/GET3D_NADA/text_templates.py b/GET3D_NADA/text_templates.py
new file mode 100644
index 0000000..f36205a
--- /dev/null
+++ b/GET3D_NADA/text_templates.py
@@ -0,0 +1,129 @@
+imagenet_templates = [
+ 'a bad photo of a {}.',
+ 'a sculpture of a {}.',
+ 'a photo of the hard to see {}.',
+ 'a low resolution photo of the {}.',
+ 'a rendering of a {}.',
+ 'graffiti of a {}.',
+ 'a bad photo of the {}.',
+ 'a cropped photo of the {}.',
+ 'a tattoo of a {}.',
+ 'the embroidered {}.',
+ 'a photo of a hard to see {}.',
+ 'a bright photo of a {}.',
+ 'a photo of a clean {}.',
+ 'a photo of a dirty {}.',
+ 'a dark photo of the {}.',
+ 'a drawing of a {}.',
+ 'a photo of my {}.',
+ 'the plastic {}.',
+ 'a photo of the cool {}.',
+ 'a close-up photo of a {}.',
+ 'a black and white photo of the {}.',
+ 'a painting of the {}.',
+ 'a painting of a {}.',
+ 'a pixelated photo of the {}.',
+ 'a sculpture of the {}.',
+ 'a bright photo of the {}.',
+ 'a cropped photo of a {}.',
+ 'a plastic {}.',
+ 'a photo of the dirty {}.',
+ 'a jpeg corrupted photo of a {}.',
+ 'a blurry photo of the {}.',
+ 'a photo of the {}.',
+ 'a good photo of the {}.',
+ 'a rendering of the {}.',
+ 'a {} in a video game.',
+ 'a photo of one {}.',
+ 'a doodle of a {}.',
+ 'a close-up photo of the {}.',
+ 'a photo of a {}.',
+ 'the origami {}.',
+ 'the {} in a video game.',
+ 'a sketch of a {}.',
+ 'a doodle of the {}.',
+ 'a origami {}.',
+ 'a low resolution photo of a {}.',
+ 'the toy {}.',
+ 'a rendition of the {}.',
+ 'a photo of the clean {}.',
+ 'a photo of a large {}.',
+ 'a rendition of a {}.',
+ 'a photo of a nice {}.',
+ 'a photo of a weird {}.',
+ 'a blurry photo of a {}.',
+ 'a cartoon {}.',
+ 'art of a {}.',
+ 'a sketch of the {}.',
+ 'a embroidered {}.',
+ 'a pixelated photo of a {}.',
+ 'itap of the {}.',
+ 'a jpeg corrupted photo of the {}.',
+ 'a good photo of a {}.',
+ 'a plushie {}.',
+ 'a photo of the nice {}.',
+ 'a photo of the small {}.',
+ 'a photo of the weird {}.',
+ 'the cartoon {}.',
+ 'art of the {}.',
+ 'a drawing of the {}.',
+ 'a photo of the large {}.',
+ 'a black and white photo of a {}.',
+ 'the plushie {}.',
+ 'a dark photo of a {}.',
+ 'itap of a {}.',
+ 'graffiti of the {}.',
+ 'a toy {}.',
+ 'itap of my {}.',
+ 'a photo of a cool {}.',
+ 'a photo of a small {}.',
+ 'a tattoo of the {}.',
+]
+
+part_templates = [
+ 'the paw of a {}.',
+ 'the nose of a {}.',
+ 'the eye of the {}.',
+ 'the ears of a {}.',
+ 'an eye of a {}.',
+ 'the tongue of a {}.',
+ 'the fur of the {}.',
+ 'colorful {} fur.',
+ 'a snout of a {}.',
+ 'the teeth of the {}.',
+ 'the {}s fangs.',
+ 'a claw of the {}.',
+ 'the face of the {}',
+ 'a neck of a {}',
+ 'the head of the {}',
+]
+
+imagenet_templates_small = [
+ 'a photo of a {}.',
+ 'a rendering of a {}.',
+ 'a cropped photo of the {}.',
+ 'the photo of a {}.',
+ 'a photo of a clean {}.',
+ 'a photo of a dirty {}.',
+ 'a dark photo of the {}.',
+ 'a photo of my {}.',
+ 'a photo of the cool {}.',
+ 'a close-up photo of a {}.',
+ 'a bright photo of the {}.',
+ 'a cropped photo of a {}.',
+ 'a photo of the {}.',
+ 'a good photo of the {}.',
+ 'a photo of one {}.',
+ 'a close-up photo of the {}.',
+ 'a rendition of the {}.',
+ 'a photo of the clean {}.',
+ 'a rendition of a {}.',
+ 'a photo of a nice {}.',
+ 'a good photo of a {}.',
+ 'a photo of the nice {}.',
+ 'a photo of the small {}.',
+ 'a photo of the weird {}.',
+ 'a photo of the large {}.',
+ 'a photo of a cool {}.',
+ 'a photo of a small {}.',
+]
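These prompt templates are meant to be instantiated with a class name and then embedded and averaged in CLIP's text space (as in StyleGAN-NADA); the instantiation step itself is plain string formatting. A small sketch with a subset of the templates:

```python
templates = [
    'a photo of a {}.',
    'a rendering of a {}.',
    'a {} in a video game.',
]

def compose_prompts(class_name, templates):
    """Fill every template with the target class name."""
    return [t.format(class_name) for t in templates]

prompts = compose_prompts('tiger', templates)
print(prompts[2])  # a tiger in a video game.
```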
diff --git a/GET3D_NADA/train_nada.py b/GET3D_NADA/train_nada.py
new file mode 100644
index 0000000..beb9037
--- /dev/null
+++ b/GET3D_NADA/train_nada.py
@@ -0,0 +1,349 @@
+"""
+train script of GeneratorGET3DNADA model
+
+Usage
+ - $ python train_nada.py --config_path [config_path] --name [exp_name] --suppress
+ - $ cat [config_path] | python train_nada.py --pipe --name [exp_name] --suppress
+
+Reference: https://github.com/rinongal/StyleGAN-nada/blob/main/ZSSGAN/train.py
+"""
+
+import sys
+import os
+
+import time
+import tempfile
+import yaml
+import numpy as np
+import torch
+from torchvision.utils import save_image
+import logging
+
+import dist_util
+from model_engine import find_get3d
+from nada import GeneratorGET3DNADA
+from functional import unfreeze_generator_layers, generate_nada_mode
+
+if find_get3d():
+ from torch_utils import custom_ops
+
+SEED = 0
+SELECT = 50
+
+
+def get_logger(exp_name, outdir, rank=0):
+ logger = logging.getLogger(exp_name)
+ if rank != 0:
+ logger.disabled = True
+ else:
+ logger.setLevel(logging.DEBUG)
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(formatter)
+ stream_handler.setLevel(logging.DEBUG)
+ logger.addHandler(stream_handler)
+ file_handler = logging.FileHandler(f'{outdir}/{exp_name}_{time.strftime("%Y-%m-%d-%H-%M", time.gmtime())}.log')
+ file_handler.setFormatter(formatter)
+ file_handler.setLevel(logging.DEBUG)
+ logger.addHandler(file_handler)
+ return logger
+
+
+def subprocess_fn(rank, config, args, temp_dir):
+
+ if config['GLOBAL']['gpus'] > 1:
+ dist_util.setup_dist(temp_dir, rank, config['GLOBAL']['gpus'])
+ try:
+ train(rank, config, args, temp_dir)
+ finally:
+ if dist_util.is_initialized():
+ dist_util.dist.destroy_process_group()
+
+
+def train(rank, config, args, temp_dir):
+
+ if rank != 0:
+ custom_ops.verbosity = 'none'
+
+ if rank == 0:
+ print("STARTING EXPERIMENT WITH NAME : ", args.name)
+ print("LOADING GeneratorGET3DNADA")
+
+ with dist_util.synchronized_ops():
+ net = GeneratorGET3DNADA(config)
+ unfreeze_generator_layers(net.generator_trainable, [], [])
+
+ if dist_util.get_world_size() > 1:
+ ddp_net = torch.nn.parallel.DistributedDataParallel(
+ net,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=True,
+ bucket_cap_mb=256,
+ find_unused_parameters=True,
+ )
+ else:
+ ddp_net = net
+
+ device, outdir, batch, n_vis, sample_1st, sample_2nd, iter_1st, iter_2nd, lr, \
+ output_interval, save_interval, gradient_clip_threshold = net.get_loop_settings()
+
+ g_optim = torch.optim.Adam(
+ net.generator_trainable.parameters(),
+ lr=lr,
+ betas=(0.9, 0.99),
+ )
+
+ with dist_util.synchronized_ops():
+ if rank == 0:
+ sample_dir = os.path.join(outdir, "sample")
+ ckpt_dir = os.path.join(outdir, "checkpoint")
+ os.makedirs(outdir, exist_ok=True)
+ os.makedirs(sample_dir, exist_ok=True)
+ os.makedirs(ckpt_dir, exist_ok=True)
+
+ torch.manual_seed(SEED)
+ np.random.seed(SEED)
+
+ logger = get_logger(args.name, outdir, rank=rank)
+ logger.info(f'EXPERIMENT NAME : {args.name} | CONFIG : {args.config_path} | SEED : {SEED} | BATCH : {batch}')
+
+ z_dim = 512 # Fixed value
+ fixed_z_geo = torch.randn(n_vis, z_dim, device=device) # for eval
+ fixed_z_tex = torch.randn(n_vis, z_dim, device=device)
+ grid_rows = int(n_vis ** 0.5)
+
+ eval_camera = net.generator_frozen.synthesis.generate_rotate_camera_list(n_batch=1)[4].repeat(n_vis, 1, 1, 1)
+ # ------------------ Training 1st --------------
+
+ # two latent codes per sample: one for geometry (geo), one for texture (tex)
+ # each GPU draws a different batch of latents <- equivalent to seeing n_batch * n_gpu latents per step
+ latent_generator = torch.Generator(device)
+ latent_generator.manual_seed(rank)
+ sample_z_geo = torch.randn(sample_1st, z_dim, device=device, generator=latent_generator)
+ sample_z_tex = torch.randn(sample_1st, z_dim, device=device, generator=latent_generator)
+
+ sample_z_geo_chunks = torch.split(sample_z_geo, batch, dim=0)
+ sample_z_tex_chunks = torch.split(sample_z_tex, batch, dim=0)
+ logger.info('START TRAINING LOOP')
+
+ min_loss_store = []
+
+ for epoch in range(iter_1st):
+ for i, (z_geo_chunk, z_tex_chunk) in enumerate(zip(sample_z_geo_chunks, sample_z_tex_chunks)):
+ # training
+ ddp_net.train()
+
+ # memory-efficient forward : support n_view rendering
+ _, loss = ddp_net(z_tex_chunk, z_geo_chunk)
+
+ if epoch == iter_1st - 1: # to choose 50 latents with low loss value
+ loss_val = loss.cpu().detach().numpy().tolist()
+ min_loss_store += loss_val
+
+ loss = loss.mean()
+ ddp_net.zero_grad()
+ loss.backward()
+
+ if gradient_clip_threshold != -1:
+ torch.nn.utils.clip_grad_norm_(net.generator_trainable.parameters(), gradient_clip_threshold)
+
+ g_optim.step()
+ logger.info(f'ITER 1st | EPOCH : {epoch} | STEP : {i:0>4} | LOSS : {loss:.5f}')
+
+ # evaluation & save results | save checkpoints
+ with dist_util.synchronized_ops():
+ if rank == 0:
+ if i % output_interval == 0:
+ ddp_net.eval()
+ with torch.no_grad():
+ sampled_dst, _ = generate_nada_mode(
+ net.generator_trainable,
+ fixed_z_tex, fixed_z_geo,
+ use_mapping=True, mode='layer', camera=eval_camera
+ )
+
+ rgb = sampled_dst[:, :-1]
+ mask = sampled_dst[:, -1:]
+ bg = torch.ones(rgb.shape, device=device)
+ bg *= 0.0001 # for better background
+ new_dst = rgb*mask + bg*(1-mask)
+
+ save_image(
+ new_dst,
+ os.path.join(sample_dir, f"Iter1st_Epoch-{epoch}_Step-{i:0>4}.png"),
+ nrow=grid_rows,
+ normalize=True,
+ range=(-1, 1),
+ )
+ logger.info(f'ITER 1st | EPOCH : {epoch} | STEP : {i:0>4} | >> Save images ...')
+
+ if i % save_interval == 0 and not args.suppress:
+ torch.save(
+ {
+ "g_ema": net.generator_trainable.state_dict(),
+ "g_optim": g_optim.state_dict(),
+ },
+ f"{ckpt_dir}/Iter1st_Epoch-{epoch}_Step-{i:0>4}.pt",
+ )
+ logger.info(f'ITER 1st | EPOCH : {epoch} | STEP : {i:0>4} | >> Save checkpoint ...')
+
+ torch.cuda.empty_cache()
+
+ dist_util.barrier()
+
+ logger.info(f"SELECT TOP {SELECT} LATENTS")
+ min_topk_val, min_topk_idx = torch.topk(torch.tensor(min_loss_store), SELECT, largest=False)
+ print("SELECT : ", min_topk_val, min_topk_idx)
+
+ # ------------------ Training 2nd --------------
+
+ selected_z_geo = sample_z_geo[min_topk_idx]
+ selected_z_tex = sample_z_tex[min_topk_idx]
+
+ selected_z_geo_chunks = torch.split(selected_z_geo, batch, dim=0)
+ selected_z_tex_chunks = torch.split(selected_z_tex, batch, dim=0)
+
+ min_loss = float('inf')
+
+ for epoch in range(iter_2nd):
+ for i, (z_geo_chunk, z_tex_chunk) in enumerate(zip(selected_z_geo_chunks, selected_z_tex_chunks)):
+ # training
+ ddp_net.train()
+
+ _, loss = ddp_net(z_tex_chunk, z_geo_chunk)
+
+ loss = loss.mean()
+ ddp_net.zero_grad()
+ loss.backward()
+
+ if gradient_clip_threshold != -1:
+ torch.nn.utils.clip_grad_norm_(net.generator_trainable.parameters(), gradient_clip_threshold)
+
+ g_optim.step()
+
+ logger.info(f'ITER 2nd | EPOCH : {epoch} | STEP : {i:0>4} | LOSS : {loss:.5f}')
+
+ # evaluation & save results | save checkpoints
+ with dist_util.synchronized_ops():
+ if rank == 0:
+ if (i == len(selected_z_geo_chunks) - 1) and (epoch == iter_2nd - 1):
+ torch.save(
+ {
+ "g_ema": net.generator_trainable.state_dict(),
+ "g_optim": g_optim.state_dict(),
+ },
+ f"{ckpt_dir}/latest.pt",
+ )
+
+ if i % output_interval == 0:
+ ddp_net.eval()
+
+ with torch.no_grad():
+ sampled_dst, _ = generate_nada_mode(
+ net.generator_trainable,
+ fixed_z_tex, fixed_z_geo, use_mapping=True, mode='layer', camera=eval_camera
+ )
+
+ rgb = sampled_dst[:, :-1]
+ mask = sampled_dst[:, -1:]
+ bg = torch.ones(rgb.shape, device=device)
+ bg *= 0.0001 # for better background
+ new_dst = rgb*mask + bg*(1-mask)
+
+ save_image(
+ new_dst,
+ os.path.join(sample_dir, f"Iter2nd_Epoch-{epoch}_Step-{i:0>4}.png"),
+ nrow=grid_rows,
+ normalize=True,
+ range=(-1, 1),
+ )
+
+ logger.info(f'ITER 2nd | EPOCH : {epoch} | STEP : {i:0>4} | >> Save images ...')
+
+ if i % save_interval == 0:
+ if not args.suppress:
+ torch.save(
+ {
+ "g_ema": net.generator_trainable.state_dict(),
+ "g_optim": g_optim.state_dict(),
+ },
+ f"{ckpt_dir}/Iter2nd_Epoch-{epoch}_Step-{i:0>4}.pt",
+ )
+
+ logger.info(f'ITER 2nd | EPOCH : {epoch} | STEP : {i:0>4} | >> Save checkpoint ...')
+
+ if loss.item() < min_loss:
+ min_loss = loss.item()
+ torch.save(
+ {
+ "g_ema": net.generator_trainable.state_dict(),
+ "g_optim": g_optim.state_dict(),
+ },
+ f"{ckpt_dir}/best.pt",
+ )
+
+ torch.cuda.empty_cache()
+ dist_util.barrier()
+
+ logger.info("TRAINING DONE ...")
+
+ # Check final results
+ with dist_util.synchronized_ops():
+ if rank == 0:
+ net.eval()
+
+ with torch.no_grad():
+ last_z_geo = torch.randn(n_vis, z_dim, device=device)
+ last_z_tex = torch.randn(n_vis, z_dim, device=device)
+ sampled_dst, _ = generate_nada_mode(
+ net.generator_trainable,
+ last_z_tex, last_z_geo, use_mapping=True, mode='layer', camera=eval_camera
+ )
+
+ save_image(
+ sampled_dst,
+ os.path.join(sample_dir, "params_latest_images.png"),
+ nrow=grid_rows,
+ normalize=True,
+ range=(-1, 1),
+ )
+
+ logger.info("FINISHED")
+
+
+def launch_training(args): # Multiprocessing spawning function
+ # Load config and parse the number of GPUs.
+ if args.pipe:
+ config = yaml.safe_load(sys.stdin)
+ else:
+ with open(args.config_path, 'r') as f:
+ config = yaml.safe_load(f)
+ gpus = config['GLOBAL']['gpus']
+
+ # In case of single GPU, directly call the training function.
+ if gpus == 1:
+ subprocess_fn(0, config, args, None)
+ return
+
+ # Otherwise, launch processes.
+ print('Launching processes...')
+ torch.multiprocessing.set_start_method('spawn', force=True)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(config, args, temp_dir), nprocs=gpus)
+
+
+def parse_args():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config_path', type=str, default='experiments/default_dist.yaml')
+ parser.add_argument('--name', type=str, default='default_dist')
+ parser.add_argument('--pipe', action='store_true', help='read config from stdin instead of file')
+ parser.add_argument('--suppress', action='store_true')
+ return parser.parse_args()
+
+
+if __name__ == '__main__':
+ launch_training(parse_args())