diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..616ba9ed --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,22 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "CUDA_VISIBLE_DEVICES": "0" + }, + "args": [ // Pass arguments here + "--fname", "configs/pretrain/vith16_384.yaml" + ], + } + ] +} \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/main.py b/app/main.py index 52e1596a..62af1c3e 100644 --- a/app/main.py +++ b/app/main.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -17,55 +17,59 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--fname', type=str, - help='name of config file to load', - default='configs.yaml') + "--fname", type=str, help="name of config file to load", default="configs/pretrain/vith16_384.yaml" +) parser.add_argument( - '--devices', type=str, nargs='+', default=['cuda:0'], - help='which devices to use on local machine') + "--devices", + type=str, + nargs="+", + default=["cuda:0"], + help="which devices to use on local machine", +) def process_main(rank, fname, world_size, devices): import os - os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) + + os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1]) import logging from src.utils.logging import get_logger + logger = get_logger(force=True) if rank == 0: logger.setLevel(logging.INFO) else: logger.setLevel(logging.ERROR) - logger.info(f'called-params {fname}') + logger.info(f"called-params {fname}") # Load config params = None - with open(fname, 'r') as y_file: + with open(fname, "r") as y_file: params = yaml.load(y_file, Loader=yaml.FullLoader) - logger.info('loaded params...') + logger.info("loaded params...") # Log config if rank == 0: pprint.PrettyPrinter(indent=4).pprint(params) - dump = os.path.join(params['logging']['folder'], 'params-pretrain.yaml') - with open(dump, 'w') as f: + dump = os.path.join(params["logging"]["folder"], "params-pretrain.yaml") + with open(dump, "w") as f: yaml.dump(params, f) # Init distributed (access to comm between GPUS on same machine) world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) - logger.info(f'Running... (rank: {rank}/{world_size})') + logger.info(f"Running... 
(rank: {rank}/{world_size})") # Launch the app with loaded config - app_main(params['app'], args=params) + app_main(params["app"], args=params) -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() num_gpus = len(args.devices) - mp.set_start_method('spawn') + mp.set_start_method("spawn") for rank in range(num_gpus): mp.Process( - target=process_main, - args=(rank, args.fname, num_gpus, args.devices) + target=process_main, args=(rank, args.fname, num_gpus, args.devices) ).start() diff --git a/app/main_distributed.py b/app/main_distributed.py index 11ac3a27..defc0987 100644 --- a/app/main_distributed.py +++ b/app/main_distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -20,32 +20,33 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--folder', type=str, - help='location to save submitit logs', - default='/fsx-jepa/massran/submitit/') + "--folder", + type=str, + help="location to save submitit logs", + default="/fsx-jepa/massran/submitit/", +) parser.add_argument( - '--exclude', type=str, - help='nodes to exclude from training', - default=None) + "--exclude", type=str, help="nodes to exclude from training", default=None +) parser.add_argument( - '--batch-launch', action='store_true', - help='whether fname points to a file to batch-lauch several config files') + "--batch-launch", + action="store_true", + help="whether fname points to a file to batch-lauch several config files", +) parser.add_argument( - '--fname', type=str, - help='yaml file containing config file names to launch', - default='configs.yaml') -parser.add_argument( - '--partition', type=str, - help='cluster partition to submit jobs on') -parser.add_argument( - '--time', type=int, default=4300, - help='time in minutes to run job') + "--fname", + type=str, + help="yaml file containing config file names to launch", + default="configs.yaml", +) +parser.add_argument("--partition", type=str, help="cluster partition to submit jobs on") +parser.add_argument("--time", type=int, default=4300, help="time in minutes to run job") class Trainer: def __init__(self, args_pretrain, load_model=None): - self.app = args_pretrain['app'] + self.app = args_pretrain["app"] self.args_pretrain = args_pretrain self.load_model = load_model @@ -54,7 +55,7 @@ def __call__(self): params = self.args_pretrain load_model = self.load_model - logger.info('loaded pretrain params...') + logger.info("loaded pretrain params...") pp = pprint.PrettyPrinter(indent=4) pp.pprint(params) @@ -64,7 +65,9 @@ def __call__(self): def checkpoint(self): fb_trainer = Trainer(self.args_pretrain, True) - return submitit.helpers.DelayedSubmission(fb_trainer,) + return submitit.helpers.DelayedSubmission( + fb_trainer, + ) def launch_app_with_parsed_args( @@ -74,19 +77,20 @@ def launch_app_with_parsed_args( timeout=4300, nodes=1, tasks_per_node=1, - exclude_nodes=None + exclude_nodes=None, ): executor = submitit.AutoExecutor( - folder=os.path.join(submitit_folder, 'job_%j'), - slurm_max_num_timeout=20) + folder=os.path.join(submitit_folder, "job_%j"), slurm_max_num_timeout=20 + ) executor.update_parameters( slurm_partition=partition, - slurm_mem_per_gpu='55G', + slurm_mem_per_gpu="55G", timeout_min=timeout, nodes=nodes, tasks_per_node=tasks_per_node, cpus_per_task=12, - gpus_per_node=tasks_per_node) + gpus_per_node=tasks_per_node, + ) if args.exclude is not None: 
executor.update_parameters(slurm_exclude=args.exclude) @@ -95,7 +99,9 @@ def launch_app_with_parsed_args( with executor.batch(): for ap in args_for_pretrain: fb_trainer = Trainer(ap) - job = executor.submit(fb_trainer,) + job = executor.submit( + fb_trainer, + ) trainers.append(fb_trainer) jobs.append(job) @@ -114,7 +120,7 @@ def launch(): # -- config, but actually specifies a list of other config files # -- to run in a slurm job array if args.batch_launch: - with open(args.fname, 'r') as y_file: + with open(args.fname, "r") as y_file: config_fnames = yaml.load(y_file, Loader=yaml.FullLoader) # ---------------------------------------------------------------------- # @@ -124,13 +130,13 @@ def launch(): nodes, tasks_per_node = None, None configs = [] for f in config_fnames: - with open(f, 'r') as y_file: + with open(f, "r") as y_file: _params = yaml.load(y_file, Loader=yaml.FullLoader) - nodes = int(_params.get('nodes')) - tasks_per_node = int(_params.get('tasks_per_node')) + nodes = int(_params.get("nodes")) + tasks_per_node = int(_params.get("tasks_per_node")) configs += [_params] - logger.info(f'Loaded {len(configs)} config files') - logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}') + logger.info(f"Loaded {len(configs)} config files") + logger.info(f"Running all jobs with {nodes=} / {tasks_per_node=}") # ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- # @@ -143,10 +149,11 @@ def launch(): timeout=args.time, nodes=nodes, tasks_per_node=tasks_per_node, - exclude_nodes=args.exclude) + exclude_nodes=args.exclude, + ) # ---------------------------------------------------------------------- # -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() launch() diff --git a/app/main_with_actions.py b/app/main_with_actions.py new file mode 100644 index 00000000..3a2b3b87 --- /dev/null +++ b/app/main_with_actions.py @@ -0,0 +1,76 @@ +# In app/main_with_actions.py + +# Copyright (c) NeoCybernetica, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+#
+
+import argparse
+import pprint
+import yaml
+import os
+import logging
+import traceback
+
+from app.scaffold import main as app_main
+from src.utils.distributed import init_distributed
+from app.vjepa.train_with_actions import main as train  # Import the main function from train_with_actions.py
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--fname",
+        type=str,
+        help="name of config file to load",
+        default="configs/pretrain/vith16_384.yaml",
+    )
+    parser.add_argument(
+        "--devices",
+        type=str,
+        nargs="+",
+        default=["cuda:0"],
+        help="which devices to use on local machine",
+    )
+
+    args = parser.parse_args()
+
+    # Initialize logging
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
+    logger.info(f"Called parameters: {args.fname}")
+
+    # Load configuration from YAML file
+    with open(args.fname, "r") as y_file:
+        params = yaml.load(y_file, Loader=yaml.FullLoader)
+    logger.info("Loaded configuration parameters.")
+
+    # Pretty print the configuration parameters
+    pprint.PrettyPrinter(indent=4).pprint(params)
+
+    # Save the configuration parameters to a YAML file
+    dump_file = os.path.join(params["logging"]["folder"], "params-pretrain.yaml")
+    os.makedirs(os.path.dirname(dump_file), exist_ok=True)
+    with open(dump_file, "w") as f:
+        yaml.dump(params, f)
+
+    # Restrict GPU visibility before any CUDA/distributed initialization (setting it later has no effect)
+    rank = 0  # local rank is 0 for a single-GPU launch
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.devices[rank].split(":")[-1])
+
+    # Initialize distributed training (for a single GPU, world_size is 1 and rank is 0)
+    num_gpus = len(args.devices)
+    world_size, rank = init_distributed(rank_and_world_size=(rank, num_gpus))
+    logger.info(f"Running... (rank: {rank}/{world_size})")
+
+    # Launch the app with loaded config
+    try:
+        train(args=params, world_size=world_size, rank=rank)
+    except Exception as e:
+        logger.error(f"An error occurred during training: {traceback.format_exc()}")
+        raise e
+
+if __name__ == "__main__":
+    main()
diff --git a/app/scaffold.py b/app/scaffold.py
index 1b49a8b3..d2f3b3e0 100644
--- a/app/scaffold.py
+++ b/app/scaffold.py
@@ -1,4 +1,4 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) NeoCybernetica, Inc. and affiliates.
 # All rights reserved.
# # This source code is licensed under the license found in the @@ -15,7 +15,7 @@ def main(app, args, resume_preempt=False): - logger.info(f'Running pre-training of app: {app}') - return importlib.import_module(f'app.{app}.train').main( - args=args, - resume_preempt=resume_preempt) + logger.info(f"Running pre-training of app: {app}") + return importlib.import_module(f"app.{app}.train").main( + args=args, resume_preempt=resume_preempt + ) diff --git a/app/vjepa/__init__.py b/app/vjepa/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/vjepa/test_validate_data_loading_pipeline.py b/app/vjepa/test_validate_data_loading_pipeline.py new file mode 100644 index 00000000..c6ffad76 --- /dev/null +++ b/app/vjepa/test_validate_data_loading_pipeline.py @@ -0,0 +1,74 @@ +from torch.utils.data import DataLoader +from torchvision import transforms +import yaml +from src.datasets.image_dataset import ImageDataset, SequentialDriveSampler +from src.masks.random_tube import MaskCollatorWithActions as TubeMaskCollator +from src.masks.multiblock3d import MaskCollatorWithActions as MB3DMaskCollator +from src.utils.logging import ( + get_logger, +) + +logger = get_logger(__name__) + +data_dir = "/home/ncdev/Documents/darwin/data/raw" +filename = "/home/ncdev/Documents/darwin/jepa/configs/pretrain/vith16_384.yaml" + +# Load configuration from YAML file +with open(filename, "r") as y_file: + params = yaml.load(y_file, Loader=yaml.FullLoader) +logger.info("Loaded configuration parameters.") + +def test_data_loader(data_dir, batch_size, mask_collator): + # Define the necessary transforms + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(lambda x: x[:3] if x.size(0) > 3 else x), # Convert to RGB if needed + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + dataset = ImageDataset(data_dir=data_dir, transform=transform) + sampler = SequentialDriveSampler(dataset) + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + collate_fn=mask_collator, + num_workers=4, + pin_memory=True, + drop_last=True, + ) + + for batch_idx, (images, maneuvers, masks_enc, masks_dec) in enumerate(data_loader): + print(f"Batch {batch_idx + 1}") + print(f"Images shape: {images.shape}") + print(f"Maneuvers shape: {maneuvers.shape}") + print(f"Encoder Masks shape: {masks_enc[0].shape}") + print(f"Decoder Masks shape: {masks_dec[0].shape}") + print("---") + + if batch_idx == 4: + break + +cfgs_mask = params.get("mask") + +# Test with TubeMaskCollator +print("Testing with TubeMaskCollator") +tube_mask_collator = TubeMaskCollator( + crop_size=224, + num_frames=16, + patch_size=16, + tubelet_size=2, + cfgs_mask=cfgs_mask, +) +test_data_loader(data_dir=data_dir, batch_size=32, mask_collator=tube_mask_collator) + +# Test with MB3DMaskCollator +print("Testing with MB3DMaskCollator") +mb3d_mask_collator = MB3DMaskCollator( + crop_size=224, + num_frames=16, + patch_size=16, + tubelet_size=2, + cfgs_mask=cfgs_mask, +) +test_data_loader(data_dir=data_dir, batch_size=32, mask_collator=mb3d_mask_collator) \ No newline at end of file diff --git a/app/vjepa/train.py b/app/vjepa/train.py index 2b556168..4002f118 100644 --- a/app/vjepa/train.py +++ b/app/vjepa/train.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -20,6 +20,7 @@ import copy import time import numpy as np +import traceback import torch import torch.multiprocessing as mp @@ -343,7 +344,7 @@ def save_checkpoint(epoch, path): try: torch.save(save_dict, path) except Exception as e: - logger.info(f'Encountered exception when saving checkpoint: {e}') + logger.info(f'Encountered exception when saving checkpoint: {traceback.format_exc}') logger.info('Initializing loader...') loader = iter(unsupervised_loader) @@ -583,4 +584,4 @@ def log_stats(): if save_every_freq > 0 and epoch % save_every_freq == 0: save_every_file = f'{tag}-e{epoch}.pth.tar' save_every_path = os.path.join(folder, save_every_file) - save_checkpoint(epoch + 1, save_every_path) + save_checkpoint(epoch + 1, save_every_path) \ No newline at end of file diff --git a/app/vjepa/train_with_actions.py b/app/vjepa/train_with_actions.py new file mode 100644 index 00000000..5190e497 --- /dev/null +++ b/app/vjepa/train_with_actions.py @@ -0,0 +1,675 @@ +# Copyright (c) NeoCybernetica, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os + + +# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS +try: + # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE + # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE + # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE + # -- TO EACH PROCESS + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] +except Exception: + pass + +import copy +import time +import numpy as np +import traceback + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel +from torchvision.transforms import ToPILImage + +from einops import rearrange + +from src.datasets.data_manager import init_data +from src.masks.random_tube import MaskCollatorWithActions as TubeMaskCollatorWithActions +from src.masks.multiblock3d import MaskCollatorWithActions as MB3DMaskCollator +from src.masks.utils import apply_masks +from src.utils.distributed import init_distributed, AllReduce +from src.utils.logging import ( + CSVLogger, + gpu_timer, + get_logger, + grad_logger, + adamw_logger, + AverageMeter, +) +from src.utils.tensors import repeat_interleave_batch, to_batch +from src.models.utils.combine_encodings import ( + combine_encodings_concat, + combine_encodings_add, + AttentionFusion, +) + + +from app.vjepa.utils import ( + load_checkpoint, + init_video_model, + init_opt, +) +from app.vjepa.transforms import make_image_transforms + + +# -- +log_timings = True +log_freq = 10 +checkpoint_freq = 1 +# -- + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + + +logger = get_logger(__name__) + +def main(args, world_size=1, rank=0, resume_preempt=False): + # First let's go over the folders and generate the + + # ----------------------------------------------------------------------- # + # PASSED IN PARAMS FROM CONFIG FILE + # ----------------------------------------------------------------------- # + + # -- META + cfgs_meta = args.get("meta") + load_model = cfgs_meta.get("load_checkpoint") or resume_preempt + r_file = cfgs_meta.get("read_checkpoint", None) + seed = cfgs_meta.get("seed", _GLOBAL_SEED) + save_every_freq = cfgs_meta.get("save_every_freq", -1) + skip_batches = 
cfgs_meta.get("skip_batches", -1) + use_sdpa = cfgs_meta.get("use_sdpa", False) + which_dtype = cfgs_meta.get("dtype") + logger.info(f"{which_dtype=}") + if which_dtype.lower() == "bfloat16": + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == "float16": + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get("mask") + + # -- MODEL + cfgs_model = args.get("model") + model_name = cfgs_model.get("model_name") + pred_depth = cfgs_model.get("pred_depth") + pred_embed_dim = cfgs_model.get("pred_embed_dim") + uniform_power = cfgs_model.get("uniform_power", True) + use_mask_tokens = cfgs_model.get("use_mask_tokens", True) + zero_init_mask_tokens = cfgs_model.get("zero_init_mask_tokens", True) + + # -- DATA + cfgs_data = args.get("data") + dataset_type = cfgs_data.get("dataset_type", "egovehicle_imagedataset") + mask_type = cfgs_data.get("mask_type", "multiblock3d") + dataset_paths = cfgs_data.get("datasets", []) + datasets_weights = cfgs_data.get("datasets_weights", None) + if datasets_weights is not None: + assert len(datasets_weights) == len( + dataset_paths + ), "Must have one sampling weight specified for each dataset" + batch_size = cfgs_data.get("batch_size") + num_clips = cfgs_data.get("num_clips") + num_frames = cfgs_data.get("num_frames") + tubelet_size = cfgs_data.get("tubelet_size") + sampling_rate = cfgs_data.get("sampling_rate") + duration = cfgs_data.get("clip_duration", None) + crop_size = cfgs_data.get("crop_size", 224) + patch_size = cfgs_data.get("patch_size") + pin_mem = cfgs_data.get("pin_mem", False) + num_workers = cfgs_data.get("num_workers", 1) + filter_short_videos = cfgs_data.get("filter_short_videos", False) + decode_one_clip = cfgs_data.get("decode_one_clip", True) + log_resource_util_data = cfgs_data.get("log_resource_utilization", False) + + # -- DATA AUGS + cfgs_data_aug = args.get("data_aug") + ar_range = cfgs_data_aug.get("random_resize_aspect_ratio", [3 / 4, 4 / 3]) + rr_scale = cfgs_data_aug.get("random_resize_scale", [0.3, 1.0]) + motion_shift = cfgs_data_aug.get("motion_shift", False) + reprob = cfgs_data_aug.get("reprob", 0.0) + use_aa = cfgs_data_aug.get("auto_augment", False) + + # -- LOSS + cfgs_loss = args.get("loss") + loss_exp = cfgs_loss.get("loss_exp") + reg_coeff = cfgs_loss.get("reg_coeff") + + # -- OPTIMIZATION + cfgs_opt = args.get("optimization") + ipe = cfgs_opt.get("ipe", None) + ipe_scale = cfgs_opt.get("ipe_scale", 1.0) + clip_grad = cfgs_opt.get("clip_grad", None) + wd = float(cfgs_opt.get("weight_decay")) + final_wd = float(cfgs_opt.get("final_weight_decay")) + num_epochs = cfgs_opt.get("epochs") + warmup = cfgs_opt.get("warmup") + start_lr = cfgs_opt.get("start_lr") + lr = cfgs_opt.get("lr") + final_lr = cfgs_opt.get("final_lr") + ema = cfgs_opt.get("ema") + betas = cfgs_opt.get("betas", (0.9, 0.999)) + eps = cfgs_opt.get("eps", 1.0e-8) + + # -- LOGGING + cfgs_logging = args.get("logging") + folder = cfgs_logging.get("folder") + tag = cfgs_logging.get("write_tag") + + # ----------------------------------------------------------------------- # + # ----------------------------------------------------------------------- # + + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = True + + # -- set device + if not torch.cuda.is_available(): + device = torch.device("cpu") + else: + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + # -- log/checkpointing paths + log_file = 
os.path.join(folder, f"{tag}_r{rank}.csv") + latest_file = f"{tag}-latest.pth.tar" + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = None + load_model = False + + # -- make csv_logger + csv_logger = CSVLogger( + log_file, + ("%d", "epoch"), + ("%d", "itr"), + ("%.5f", "loss"), + ("%.5f", "loss-jepa"), + ("%.5f", "reg-loss"), + ("%.5f", "enc-grad-norm"), + ("%.5f", "pred-grad-norm"), + ("%d", "gpu-time(ms)"), + ("%d", "wall-time(ms)"), + ) + + # -- init model + encoder, predictor, action_encoder = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + target_encoder = copy.deepcopy(encoder) + + # -- make data transforms + if mask_type == "multiblock3d": + logger.info("Initializing basic multi-block mask") + mask_collator = MB3DMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask, + ) + else: + logger.info("Initializing random tube mask") + mask_collator = TubeMaskCollatorWithActions( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask, + ) + transform = make_image_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=ar_range, + random_resize_scale=rr_scale, + reprob=reprob, + auto_augment=use_aa, + crop_size=crop_size, + ) + + # -- init data-loaders/samplers + (unsupervised_loader, unsupervised_sampler) = init_data( + data=dataset_type, + root_path=dataset_paths, + batch_size=batch_size, + training=True, + clip_len=num_frames, + frame_sample_rate=sampling_rate, + filter_short_videos=filter_short_videos, + decode_one_clip=decode_one_clip, + duration=duration, + num_clips=num_clips, + transform=transform, + datasets_weights=datasets_weights, + collator=mask_collator, + num_workers=num_workers, + world_size=world_size, + pin_mem=pin_mem, + rank=rank, + log_dir=folder if log_resource_util_data else None, + ) + try: + _dlen = len(unsupervised_loader) + except Exception: # Different interface for webdataset + _dlen = unsupervised_loader.num_batches + if ipe is None: + ipe = _dlen + logger.info(f"iterations per epoch/dataest length: {ipe}/{_dlen}") + + # -- init optimizer and scheduler + optimizer, scaler, scheduler, wd_scheduler = init_opt( + encoder=encoder, + predictor=predictor, + wd=wd, + final_wd=final_wd, + start_lr=start_lr, + ref_lr=lr, + final_lr=final_lr, + iterations_per_epoch=ipe, + warmup=warmup, + num_epochs=num_epochs, + ipe_scale=ipe_scale, + mixed_precision=mixed_precision, + betas=betas, + eps=eps, + ) + encoder = DistributedDataParallel(encoder, static_graph=True) + predictor = DistributedDataParallel(predictor, static_graph=True) + target_encoder = DistributedDataParallel(target_encoder) + for p in target_encoder.parameters(): + p.requires_grad = False + + # -- momentum schedule + momentum_scheduler = ( + ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) + for i in range(int(ipe * num_epochs * ipe_scale) + 1) + ) + + start_epoch = 0 + # -- load training checkpoint + if load_model or 
os.path.exists(latest_path):
+        (
+            encoder,
+            predictor,
+            target_encoder,
+            optimizer,
+            scaler,
+            start_epoch,
+        ) = load_checkpoint(
+            r_path=load_path,
+            encoder=encoder,
+            predictor=predictor,
+            target_encoder=target_encoder,
+            opt=optimizer,
+            scaler=scaler,
+        )
+        for _ in range(start_epoch * ipe):
+            scheduler.step()
+            wd_scheduler.step()
+            next(momentum_scheduler)
+            mask_collator.step()
+
+    def save_checkpoint(epoch, path):
+        if rank != 0:
+            return
+        save_dict = {
+            "encoder": encoder.state_dict(),
+            "predictor": predictor.state_dict(),
+            "opt": optimizer.state_dict(),
+            "scaler": None if scaler is None else scaler.state_dict(),
+            "target_encoder": target_encoder.state_dict(),
+            "epoch": epoch,
+            "loss": loss_meter.avg,
+            "batch_size": batch_size,
+            "world_size": world_size,
+            "lr": lr,
+        }
+        try:
+            torch.save(save_dict, path)
+        except Exception as e:
+            logger.info(f"Encountered exception when saving checkpoint: {traceback.format_exc()}")
+
+    logger.info("Initializing loader...")
+    loader = iter(unsupervised_loader)
+
+    # -- TRAINING LOOP
+    for epoch in range(start_epoch, num_epochs):
+        logger.info("Epoch %d" % (epoch + 1))
+
+        # -- update distributed-data-loader epoch
+        unsupervised_sampler.set_epoch(epoch)
+
+        loss_meter = AverageMeter()
+        input_var_meter = AverageMeter()
+        input_var_min_meter = AverageMeter()
+        jepa_loss_meter = AverageMeter()
+        reg_loss_meter = AverageMeter()
+        mask_meters = [AverageMeter() for _ in range(len(cfgs_mask))]
+        gpu_time_meter = AverageMeter()
+        wall_time_meter = AverageMeter()
+
+        for itr in range(ipe):
+            itr_start_time = time.time()
+
+            try:
+                collated_images, collated_maneuvers, masks_enc, masks_pred = next(loader)
+
+            except StopIteration:
+                logger.info(
+                    "Exhausted data loaders before completing all planned iterations. Ending epoch early..."
+                )
+                break  # Exit the current epoch loop if there are no more data points to process
+
+            assert len(masks_enc) == len(
+                masks_pred
+            ), "Currently require num encoder masks = num predictor masks"
+
+            def load_images_and_actions():
+                try:
+                    images = []
+                    to_pil = ToPILImage()  # Create an instance of ToPILImage
+
+                    for i in range(len(collated_images)):
+                        image = collated_images[i]
+                        image = to_pil(image)  # Convert the PyTorch tensor to a PIL Image
+                        image = transform(image)  # Apply the transformation to the PIL image
+                        images.append(image)
+
+                    # Stack the transformed images into a single batched tensor
+                    images = torch.stack(images, dim=0).to(device, non_blocking=True)
+
+                    # -- Encode actions
+                    encoded_actions = action_encoder(collated_maneuvers)
+
+                    # -- Move encoder/predictor masks to the compute device
+                    _masks_enc, _masks_pred = [], []
+                    for _me, _mp in zip(masks_enc, masks_pred):
+                        _me = _me.to(device, non_blocking=True)
+                        _mp = _mp.to(device, non_blocking=True)
+                        _masks_enc.append(_me)
+                        _masks_pred.append(_mp)
+
+                    return images, encoded_actions, _masks_enc, _masks_pred
+                except Exception as e:
+                    logger.error(f"Error in load_images_and_actions: {str(e)}")
+                    raise e
+
+
+            images, encoded_actions, masks_enc, masks_pred = load_images_and_actions()
+
+            for _i, m in enumerate(mask_meters):
+                m.update(masks_enc[_i][0].size(-1))
+
+            def train_step():
+                _new_lr = scheduler.step()
+                _new_wd = wd_scheduler.step()
+                # --
+
+                def forward_target(images):
+                    """
+                    Encodes the target images using the target encoder and returns the embeddings.
+
+                    Args:
+                        images (torch.Tensor): A tensor of shape [B, T, C, H, W] representing a batch
+                        of image sequences.
+
+                    Returns:
+                        torch.Tensor: A tensor of shape [B, T, D] representing the encoded image embeddings,
+                        where D is the embedding dimension.
+                    """
+                    with torch.no_grad():
+                        image_embeddings = target_encoder(images)
+                        # Normalize the embeddings across the feature dimension
+                        normalized_embeddings = F.layer_norm(image_embeddings, (image_embeddings.size(-1),))
+
+                        # Extract the embeddings for the next frames as targets
+                        next_frame_embeddings = normalized_embeddings[:, 1:, :]  # Assuming frames are sequential
+                        return next_frame_embeddings
+
+                def forward_context(images, encoded_actions, h):
+                    """
+                    Encodes context images with the encoder, combines with encoded actions,
+                    and predicts masked regions using the predictor.
+
+                    Args:
+                        images (torch.Tensor): A tensor of shape [B, T, C, H, W] representing a batch
+                        of image sequences.
+                        encoded_actions (torch.Tensor): A tensor of shape [B, T, A] representing encoded actions,
+                        where A is the action embedding dimension.
+                        h (torch.Tensor): The hidden state from the target encoder. (Ground truth)
+
+                    Returns:
+                        torch.Tensor: A list of tensors representing the predicted values for the masked regions.
+                    """
+                    try:
+                        image_embeddings = encoder(images, masks_enc)
+
+                        # Combine image and action embeddings
+                        combined_embeddings = combine_encodings_concat(image_embeddings, encoded_actions)
+
+                        # Predict masked regions
+                        predictions = predictor(combined_embeddings, h, masks_enc, masks_pred)
+                        return predictions
+                    except Exception as e:
+                        logger.error(f"Error in forward_context: {str(e)}")
+                        raise e
+
+
+                def loss_fn(z_next, h_next):
+                    loss = 0.0
+                    # Compute loss between predicted next frames and ground truth next frames
+                    for zi, hi in zip(z_next, h_next):
+                        loss += torch.mean(torch.abs(zi - hi) ** loss_exp) / loss_exp
+                    loss /= len(h_next)
+                    return loss
+
+                def reg_fn(z):
+                    return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(
+                        z
+                    )
+
+                # Step 1. Forward
+                loss_jepa, loss_reg = 0.0, 0.0
+                with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision):
+                    h_next = forward_target(images)
+                    z_next = forward_context(images, encoded_actions, h_next)
+                    loss_jepa = loss_fn(z_next, h_next)  # jepa prediction loss
+                    pstd_z = reg_fn(z_next)  # predictor variance across patches
+                    loss_reg += torch.mean(F.relu(1.0 - pstd_z))
+                loss = loss_jepa + reg_coeff * loss_reg
+
+                # Step 2. Backward & step
+                _enc_norm, _pred_norm = 0.0, 0.0
+                if mixed_precision:
+                    scaler.scale(loss).backward()
+                    scaler.unscale_(optimizer)
+                else:
+                    loss.backward()
+                if (epoch > warmup) and (clip_grad is not None):
+                    _enc_norm = torch.nn.utils.clip_grad_norm_(
+                        encoder.parameters(), clip_grad
+                    )
+                    _pred_norm = torch.nn.utils.clip_grad_norm_(
+                        predictor.parameters(), clip_grad
+                    )
+                if mixed_precision:
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    optimizer.step()
+                grad_stats = grad_logger(encoder.named_parameters())
+                grad_stats.global_norm = float(_enc_norm)
+                grad_stats_pred = grad_logger(predictor.named_parameters())
+                grad_stats_pred.global_norm = float(_pred_norm)
+                optimizer.zero_grad()
+                optim_stats = adamw_logger(optimizer)
+
+                # Step 3. 
momentum update of target encoder + m = next(momentum_scheduler) + with torch.no_grad(): + for param_q, param_k in zip( + encoder.parameters(), target_encoder.parameters() + ): + param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data) + + return ( + float(loss), + float(loss_jepa), + float(loss_reg), + _new_lr, + _new_wd, + grad_stats, + grad_stats_pred, + optim_stats, + ) + + ( + loss, + loss_jepa, + loss_reg, + _new_lr, + _new_wd, + grad_stats, + grad_stats_pred, + optim_stats, + ), gpu_etime_ms = gpu_timer(train_step) + iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000.0 + loss_meter.update(loss) + input_var = float( + AllReduce.apply(images.view(images.shape[0], -1).var(dim=1).mean(dim=0)) + ) + input_var_min = float( + AllReduce.apply(torch.min(images.view(images.shape[0], -1).var(dim=1))) + ) + input_var_meter.update(input_var) + input_var_min_meter.update(input_var_min) + jepa_loss_meter.update(loss_jepa) + reg_loss_meter.update(loss_reg) + gpu_time_meter.update(gpu_etime_ms) + wall_time_meter.update(iter_elapsed_time_ms) + + # -- Logging + def log_stats(): + csv_logger.log( + epoch + 1, + itr, + loss, + loss_jepa, + loss_reg, + grad_stats.global_norm, + grad_stats_pred.global_norm, + gpu_etime_ms, + iter_elapsed_time_ms, + ) + if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss): + logger.info( + "[%d, %5d] loss: %.3f | p%.3f r%.3f | " + "input_var: %.3f %.3f | " + "masks: %s " + "[wd: %.2e] [lr: %.2e] " + "[mem: %.2e] " + "[gpu: %.1f ms]" + "[wall: %.1f ms]" + % ( + epoch + 1, + itr, + loss_meter.avg, + jepa_loss_meter.avg, + reg_loss_meter.avg, + input_var_meter.avg, + input_var_min_meter.avg, + "[" + + ", ".join(["%.1f" % m.avg for m in mask_meters]) + + "]", + _new_wd, + _new_lr, + torch.cuda.max_memory_allocated() / 1024.0**2, + gpu_time_meter.avg, + wall_time_meter.avg, + ) + ) + + if optim_stats is not None: + logger.info( + "[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]" + % ( + epoch + 1, + itr, + optim_stats.get("exp_avg").avg, + optim_stats.get("exp_avg").min, + optim_stats.get("exp_avg").max, + optim_stats.get("exp_avg_sq").avg, + optim_stats.get("exp_avg_sq").min, + optim_stats.get("exp_avg_sq").max, + ) + ) + + if grad_stats is not None: + logger.info( + "[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e" + % ( + epoch + 1, + itr, + grad_stats.first_layer, + grad_stats.last_layer, + grad_stats.min, + grad_stats.max, + grad_stats.global_norm, + ) + ) + + if grad_stats_pred is not None: + logger.info( + "[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e" + % ( + epoch + 1, + itr, + grad_stats_pred.first_layer, + grad_stats_pred.last_layer, + grad_stats_pred.min, + grad_stats_pred.max, + grad_stats_pred.global_norm, + ) + ) + + log_stats() + assert not np.isnan(loss), "loss is nan" + + # -- Save Checkpoint + logger.info("avg. loss %.3f" % loss_meter.avg) + # -- Save Last + if epoch % checkpoint_freq == 0 or epoch == (num_epochs - 1): + save_checkpoint(epoch + 1, latest_path) + if save_every_freq > 0 and epoch % save_every_freq == 0: + save_every_file = f"{tag}-e{epoch}.pth.tar" + save_every_path = os.path.join(folder, save_every_file) + save_checkpoint(epoch + 1, save_every_path) diff --git a/app/vjepa/transforms.py b/app/vjepa/transforms.py index 0854dd9a..aa24ead7 100644 --- a/app/vjepa/transforms.py +++ b/app/vjepa/transforms.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -12,16 +12,50 @@ from src.datasets.utils.video.randerase import RandomErasing +def make_image_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=(3 / 4, 4 / 3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), +): + + transform_list = [ + transforms.RandomResizedCrop( + crop_size, + scale=random_resize_scale, + ratio=random_resize_aspect_ratio, + ), + ] + + if random_horizontal_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + if auto_augment: + transform_list.append(transforms.AutoAugment()) + + transform_list.extend([ + transforms.ToTensor(), + transforms.Normalize(mean=normalize[0], std=normalize[1]), + ]) + + if reprob > 0: + transform_list.append(transforms.RandomErasing(p=reprob)) + + return transforms.Compose(transform_list) + + def make_transforms( random_horizontal_flip=True, - random_resize_aspect_ratio=(3/4, 4/3), + random_resize_aspect_ratio=(3 / 4, 4 / 3), random_resize_scale=(0.3, 1.0), reprob=0.0, auto_augment=False, motion_shift=False, crop_size=224, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): _frames_augmentation = VideoTransform( @@ -42,14 +76,13 @@ class VideoTransform(object): def __init__( self, random_horizontal_flip=True, - random_resize_aspect_ratio=(3/4, 4/3), + random_resize_aspect_ratio=(3 / 4, 4 / 3), random_resize_scale=(0.3, 1.0), reprob=0.0, auto_augment=False, motion_shift=False, crop_size=224, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): self.random_horizontal_flip = random_horizontal_flip @@ -62,25 +95,28 @@ def __init__( self.std = torch.tensor(normalize[1], dtype=torch.float32) if not self.auto_augment: # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. - self.mean *= 255. - self.std *= 255. + self.mean *= 255.0 + self.std *= 255.0 self.autoaug_transform = video_transforms.create_random_augment( input_size=(crop_size, crop_size), - auto_augment='rand-m7-n4-mstd0.5-inc1', - interpolation='bicubic', + auto_augment="rand-m7-n4-mstd0.5-inc1", + interpolation="bicubic", ) - self.spatial_transform = video_transforms.random_resized_crop_with_shift \ - if motion_shift else video_transforms.random_resized_crop + self.spatial_transform = ( + video_transforms.random_resized_crop_with_shift + if motion_shift + else video_transforms.random_resized_crop + ) self.reprob = reprob self.erase_transform = RandomErasing( reprob, - mode='pixel', + mode="pixel", max_count=1, num_splits=1, - device='cpu', + device="cpu", ) def __call__(self, buffer): diff --git a/app/vjepa/utils.py b/app/vjepa/utils.py index dc8668dc..046018ab 100644 --- a/app/vjepa/utils.py +++ b/app/vjepa/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -9,16 +9,16 @@ import sys import warnings import yaml +import traceback import torch import src.models.vision_transformer as video_vit +from src.models.action_encoders import ActionEncoderContinuous, ActionEncoderDiscrete import src.models.predictor as vit_pred from src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper -from src.utils.schedulers import ( - WarmupCosineSchedule, - CosineWDSchedule) +from src.utils.schedulers import WarmupCosineSchedule, CosineWDSchedule from src.utils.tensors import trunc_normal_ logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -34,43 +34,43 @@ def load_checkpoint( scaler, ): try: - checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + checkpoint = torch.load(r_path, map_location=torch.device("cpu")) except Exception as e: - logger.info(f'Encountered exception when loading checkpoint {e}') + logger.info(f"Encountered exception when loading checkpoint {traceback.format_exc}") epoch = 0 try: - epoch = checkpoint['epoch'] + epoch = checkpoint["epoch"] # -- loading encoder - pretrained_dict = checkpoint['encoder'] + pretrained_dict = checkpoint["encoder"] msg = encoder.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}") # -- loading predictor - pretrained_dict = checkpoint['predictor'] + pretrained_dict = checkpoint["predictor"] msg = predictor.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}') + logger.info(f"loaded pretrained predictor from epoch {epoch} with msg: {msg}") # -- loading target_encoder if target_encoder is not None: print(list(checkpoint.keys())) - pretrained_dict = checkpoint['target_encoder'] + pretrained_dict = checkpoint["target_encoder"] msg = target_encoder.load_state_dict(pretrained_dict) logger.info( - f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}' + f"loaded pretrained target encoder from epoch {epoch} with msg: {msg}" ) # -- loading optimizer - opt.load_state_dict(checkpoint['opt']) + opt.load_state_dict(checkpoint["opt"]) if scaler is not None: - scaler.load_state_dict(checkpoint['scaler']) - logger.info(f'loaded optimizers from epoch {epoch}') - logger.info(f'read-path: {r_path}') + scaler.load_state_dict(checkpoint["scaler"]) + logger.info(f"loaded optimizers from epoch {epoch}") + logger.info(f"read-path: {r_path}") del checkpoint except Exception as e: - logger.info(f'Encountered exception when loading checkpoint {e}') + logger.info(f"Encountered exception when loading checkpoint {traceback.format_exc}") epoch = 0 return ( @@ -88,7 +88,7 @@ def init_video_model( patch_size=16, num_frames=16, tubelet_size=2, - model_name='vit_base', + model_name="vit_base", crop_size=224, pred_depth=6, pred_embed_dim=384, @@ -97,6 +97,10 @@ def init_video_model( num_mask_tokens=2, zero_init_mask_tokens=True, use_sdpa=False, + action_type: str = "disc", # "cont", + num_actions=19, + embed_dim=32, + hidden_dim=32, ): encoder = video_vit.__dict__[model_name]( img_size=crop_size, @@ -107,7 +111,8 @@ def init_video_model( use_sdpa=use_sdpa, ) encoder = MultiMaskWrapper(encoder) - predictor = vit_pred.__dict__['vit_predictor']( + + predictor = vit_pred.__dict__["vit_predictor"]( img_size=crop_size, use_mask_tokens=use_mask_tokens, patch_size=patch_size, @@ -124,6 +129,13 @@ def init_video_model( ) predictor = 
PredictorMultiMaskWrapper(predictor) + if action_type == "disc": + action_encoder = ActionEncoderDiscrete( + num_actions=num_actions, embed_dim=embed_dim, hidden_dim=hidden_dim + ) + # else: + # action_encoder = ActionEncoderContinuous(input_dim=) + def init_weights(m): if isinstance(m, torch.nn.Linear): trunc_normal_(m.weight, std=0.02) @@ -147,10 +159,10 @@ def init_weights(m): def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info(f'Encoder number of parameters: {count_parameters(encoder)}') - logger.info(f'Predictor number of parameters: {count_parameters(predictor)}') + logger.info(f"Encoder number of parameters: {count_parameters(encoder)}") + logger.info(f"Predictor number of parameters: {count_parameters(predictor)}") - return encoder, predictor + return encoder, predictor, action_encoder def init_opt( @@ -172,25 +184,40 @@ def init_opt( ): param_groups = [ { - 'params': (p for n, p in encoder.named_parameters() - if ('bias' not in n) and (len(p.shape) != 1)) - }, { - 'params': (p for n, p in predictor.named_parameters() - if ('bias' not in n) and (len(p.shape) != 1)) - }, { - 'params': (p for n, p in encoder.named_parameters() - if ('bias' in n) or (len(p.shape) == 1)), - 'WD_exclude': zero_init_bias_wd, - 'weight_decay': 0, - }, { - 'params': (p for n, p in predictor.named_parameters() - if ('bias' in n) or (len(p.shape) == 1)), - 'WD_exclude': zero_init_bias_wd, - 'weight_decay': 0, + "params": ( + p + for n, p in encoder.named_parameters() + if ("bias" not in n) and (len(p.shape) != 1) + ) + }, + { + "params": ( + p + for n, p in predictor.named_parameters() + if ("bias" not in n) and (len(p.shape) != 1) + ) + }, + { + "params": ( + p + for n, p in encoder.named_parameters() + if ("bias" in n) or (len(p.shape) == 1) + ), + "WD_exclude": zero_init_bias_wd, + "weight_decay": 0, + }, + { + "params": ( + p + for n, p in predictor.named_parameters() + if ("bias" in n) or (len(p.shape) == 1) + ), + "WD_exclude": zero_init_bias_wd, + "weight_decay": 0, }, ] - logger.info('Using AdamW') + logger.info("Using AdamW") optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps) scheduler = WarmupCosineSchedule( optimizer, diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/configs/evals/__init__.py b/configs/evals/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/configs/pretrain/__init__.py b/configs/pretrain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/configs/pretrain/vith16_384.yaml b/configs/pretrain/vith16_384.yaml index 9c4055a3..9c9646a2 100644 --- a/configs/pretrain/vith16_384.yaml +++ b/configs/pretrain/vith16_384.yaml @@ -2,11 +2,11 @@ app: vjepa nodes: 30 tasks_per_node: 8 data: - dataset_type: VideoDataset + dataset_type: egovehicle_imagedataset datasets: - - /your_path_to_kinetics710_csv_file_index.csv - - /your_path_to_ssv2_csv_file_index.csv - - /your_path_to_howto100m_csv_file_index.csv + - /home/ncdev/Documents/darwin/data/raw/ + # - /your_path_to_ssv2_csv_file_index.csv + # - /your_path_to_howto100m_csv_file_index.csv decode_one_clip: true batch_size: 10 num_clips: 1 @@ -30,7 +30,7 @@ data_aug: - 1.0 reprob: 0.0 logging: - folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/ + folder: /home/ncdev/Documents/darwin/jepa/logs_and_checkpoints write_tag: jepa loss: loss_exp: 1.0 @@ -87,4 +87,4 @@ optimization: final_lr: 1.0e-06 ema: - 0.998 - - 1.0 + - 1.0 \ No newline at end of file diff --git 
a/evals/__init__.py b/evals/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/evals/image_classification_frozen/eval.py b/evals/image_classification_frozen/eval.py index 56d2f28e..0be356b0 100644 --- a/evals/image_classification_frozen/eval.py +++ b/evals/image_classification_frozen/eval.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -13,12 +13,13 @@ # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE # -- TO EACH PROCESS - os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] except Exception: pass import logging import pprint +import traceback import numpy as np @@ -35,18 +36,12 @@ from src.datasets.data_manager import ( init_data, ) -from src.utils.distributed import ( - init_distributed, - AllReduce -) +from src.utils.distributed import init_distributed, AllReduce from src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule, ) -from src.utils.logging import ( - AverageMeter, - CSVLogger -) +from src.utils.logging import AverageMeter, CSVLogger logging.basicConfig() logger = logging.getLogger() @@ -67,76 +62,75 @@ def main(args_eval, resume_preempt=False): # ----------------------------------------------------------------------- # # -- PRETRAIN - args_pretrain = args_eval.get('pretrain') - checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder') - model_name = args_pretrain.get('model_name', None) - patch_size = args_pretrain.get('patch_size', None) - pretrain_folder = args_pretrain.get('folder', None) - ckp_fname = args_pretrain.get('checkpoint', None) - tag = args_pretrain.get('write_tag', None) - use_sdpa = args_pretrain.get('use_sdpa', True) - use_SiLU = args_pretrain.get('use_silu', False) - tight_SiLU = args_pretrain.get('tight_silu', True) - uniform_power = args_pretrain.get('uniform_power', False) + args_pretrain = args_eval.get("pretrain") + checkpoint_key = args_pretrain.get("checkpoint_key", "target_encoder") + model_name = args_pretrain.get("model_name", None) + patch_size = args_pretrain.get("patch_size", None) + pretrain_folder = args_pretrain.get("folder", None) + ckp_fname = args_pretrain.get("checkpoint", None) + tag = args_pretrain.get("write_tag", None) + use_sdpa = args_pretrain.get("use_sdpa", True) + use_SiLU = args_pretrain.get("use_silu", False) + tight_SiLU = args_pretrain.get("tight_silu", True) + uniform_power = args_pretrain.get("uniform_power", False) pretrained_path = os.path.join(pretrain_folder, ckp_fname) # Optional [for Video model]: - tubelet_size = args_pretrain.get('tubelet_size', 2) - frames_per_clip = args_pretrain.get('frames_per_clip', 1) + tubelet_size = args_pretrain.get("tubelet_size", 2) + frames_per_clip = args_pretrain.get("frames_per_clip", 1) # -- DATA - args_data = args_eval.get('data') - dataset_name = args_data.get('dataset_name') - num_classes = args_data.get('num_classes') - root_path = args_data.get('root_path', None) - image_folder = args_data.get('image_folder', None) - resolution = args_data.get('resolution', 224) + args_data = args_eval.get("data") + dataset_name = args_data.get("dataset_name") + num_classes = args_data.get("num_classes") + root_path = args_data.get("root_path", None) + image_folder = args_data.get("image_folder", None) + resolution = 
args_data.get("resolution", 224) # -- OPTIMIZATION - args_opt = args_eval.get('optimization') - batch_size = args_opt.get('batch_size') - num_epochs = args_opt.get('num_epochs') - wd = args_opt.get('weight_decay') - start_lr = args_opt.get('start_lr') - lr = args_opt.get('lr') - final_lr = args_opt.get('final_lr') - warmup = args_opt.get('warmup') - use_bfloat16 = args_opt.get('use_bfloat16') + args_opt = args_eval.get("optimization") + batch_size = args_opt.get("batch_size") + num_epochs = args_opt.get("num_epochs") + wd = args_opt.get("weight_decay") + start_lr = args_opt.get("start_lr") + lr = args_opt.get("lr") + final_lr = args_opt.get("final_lr") + warmup = args_opt.get("warmup") + use_bfloat16 = args_opt.get("use_bfloat16") # -- EXPERIMENT-ID/TAG (optional) - resume_checkpoint = args_eval.get('resume_checkpoint', False) or resume_preempt - eval_tag = args_eval.get('tag', None) + resume_checkpoint = args_eval.get("resume_checkpoint", False) or resume_preempt + eval_tag = args_eval.get("tag", None) # ----------------------------------------------------------------------- # try: - mp.set_start_method('spawn') + mp.set_start_method("spawn") except Exception: pass if not torch.cuda.is_available(): - device = torch.device('cpu') + device = torch.device("cpu") else: - device = torch.device('cuda:0') + device = torch.device("cuda:0") torch.cuda.set_device(device) world_size, rank = init_distributed() - logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + logger.info(f"Initialized (rank/world-size) {rank}/{world_size}") # -- log/checkpointing paths - folder = os.path.join(pretrain_folder, 'image_classification_frozen/') + folder = os.path.join(pretrain_folder, "image_classification_frozen/") if eval_tag is not None: folder = os.path.join(folder, eval_tag) if not os.path.exists(folder): os.makedirs(folder, exist_ok=True) - log_file = os.path.join(folder, f'{tag}_r{rank}.csv') - latest_path = os.path.join(folder, f'{tag}-latest.pth.tar') + log_file = os.path.join(folder, f"{tag}_r{rank}.csv") + latest_path = os.path.join(folder, f"{tag}-latest.pth.tar") # -- make csv_logger if rank == 0: - csv_logger = CSVLogger(log_file, - ('%d', 'epoch'), - ('%.5f', 'loss'), - ('%.5f', 'acc')) + csv_logger = CSVLogger( + log_file, ("%d", "epoch"), ("%.5f", "loss"), ("%.5f", "acc") + ) # Initialize model @@ -153,7 +147,8 @@ def main(args_eval, resume_preempt=False): checkpoint_key=checkpoint_key, use_SiLU=use_SiLU, tight_SiLU=tight_SiLU, - use_sdpa=use_sdpa) + use_sdpa=use_sdpa, + ) encoder.eval() for p in encoder.parameters(): p.requires_grad = False @@ -163,7 +158,7 @@ def main(args_eval, resume_preempt=False): embed_dim=encoder.embed_dim, num_heads=encoder.num_heads, depth=1, - num_classes=num_classes + num_classes=num_classes, ).to(device) train_loader = make_dataloader( @@ -174,7 +169,8 @@ def main(args_eval, resume_preempt=False): batch_size=batch_size, world_size=world_size, rank=rank, - training=True) + training=True, + ) val_loader = make_dataloader( dataset_name=dataset_name, root_path=root_path, @@ -183,9 +179,10 @@ def main(args_eval, resume_preempt=False): batch_size=batch_size, world_size=world_size, rank=rank, - training=False) + training=False, + ) ipe = len(train_loader) - logger.info(f'Dataloader created... iterations per epoch: {ipe}') + logger.info(f"Dataloader created... 
iterations per epoch: {ipe}") # -- optimizer and scheduler optimizer, scaler, scheduler, wd_scheduler = init_opt( @@ -197,7 +194,8 @@ def main(args_eval, resume_preempt=False): iterations_per_epoch=ipe, warmup=warmup, num_epochs=num_epochs, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) classifier = DistributedDataParallel(classifier, static_graph=True) # -- load training checkpoint @@ -208,27 +206,28 @@ def main(args_eval, resume_preempt=False): r_path=latest_path, classifier=classifier, opt=optimizer, - scaler=scaler) - for _ in range(start_epoch*ipe): + scaler=scaler, + ) + for _ in range(start_epoch * ipe): scheduler.step() wd_scheduler.step() def save_checkpoint(epoch): save_dict = { - 'classifier': classifier.state_dict(), - 'opt': optimizer.state_dict(), - 'scaler': None if scaler is None else scaler.state_dict(), - 'epoch': epoch, - 'batch_size': batch_size, - 'world_size': world_size, - 'lr': lr + "classifier": classifier.state_dict(), + "opt": optimizer.state_dict(), + "scaler": None if scaler is None else scaler.state_dict(), + "epoch": epoch, + "batch_size": batch_size, + "world_size": world_size, + "lr": lr, } if rank == 0: torch.save(save_dict, latest_path) # TRAIN LOOP for epoch in range(start_epoch, num_epochs): - logger.info('Epoch %d' % (epoch + 1)) + logger.info("Epoch %d" % (epoch + 1)) train_acc = run_one_epoch( device=device, training=True, @@ -239,7 +238,8 @@ def save_checkpoint(epoch): scheduler=scheduler, wd_scheduler=wd_scheduler, data_loader=train_loader, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) val_acc = run_one_epoch( device=device, @@ -251,9 +251,12 @@ def save_checkpoint(epoch): scheduler=scheduler, wd_scheduler=wd_scheduler, data_loader=val_loader, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) - logger.info('[%5d] train: %.3f%% test: %.3f%%' % (epoch + 1, train_acc, val_acc)) + logger.info( + "[%5d] train: %.3f%% test: %.3f%%" % (epoch + 1, train_acc, val_acc) + ) if rank == 0: csv_logger.log(epoch + 1, train_acc, val_acc) save_checkpoint(epoch + 1) @@ -292,7 +295,7 @@ def run_one_epoch( outputs = classifier(outputs) loss = criterion(outputs, labels) - top1_acc = 100. 
* outputs.max(dim=1).indices.eq(labels).sum() / len(imgs) + top1_acc = 100.0 * outputs.max(dim=1).indices.eq(labels).sum() / len(imgs) top1_acc = float(AllReduce.apply(top1_acc)) top1_meter.update(top1_acc) @@ -310,68 +313,70 @@ def run_one_epoch( optimizer.zero_grad() if itr % 20 == 0: - logger.info('[%5d] %.3f%% (loss: %.3f) [mem: %.2e]' - % (itr, top1_meter.avg, loss, - torch.cuda.max_memory_allocated() / 1024.**2)) + logger.info( + "[%5d] %.3f%% (loss: %.3f) [mem: %.2e]" + % ( + itr, + top1_meter.avg, + loss, + torch.cuda.max_memory_allocated() / 1024.0**2, + ) + ) return top1_meter.avg -def load_checkpoint( - device, - r_path, - classifier, - opt, - scaler -): +def load_checkpoint(device, r_path, classifier, opt, scaler): try: - checkpoint = torch.load(r_path, map_location=torch.device('cpu')) - epoch = checkpoint['epoch'] + checkpoint = torch.load(r_path, map_location=torch.device("cpu")) + epoch = checkpoint["epoch"] # -- loading encoder - pretrained_dict = checkpoint['classifier'] + pretrained_dict = checkpoint["classifier"] msg = classifier.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained classifier from epoch {epoch} with msg: {msg}') + logger.info(f"loaded pretrained classifier from epoch {epoch} with msg: {msg}") # -- loading optimizer - opt.load_state_dict(checkpoint['opt']) + opt.load_state_dict(checkpoint["opt"]) if scaler is not None: - scaler.load_state_dict(checkpoint['scaler']) - logger.info(f'loaded optimizers from epoch {epoch}') - logger.info(f'read-path: {r_path}') + scaler.load_state_dict(checkpoint["scaler"]) + logger.info(f"loaded optimizers from epoch {epoch}") + logger.info(f"read-path: {r_path}") del checkpoint except Exception as e: - logger.info(f'Encountered exception when loading checkpoint {e}') + logger.info(f"Encountered exception when loading checkpoint {traceback.format_exc}") epoch = 0 return classifier, opt, scaler, epoch -def load_pretrained( - encoder, - pretrained, - checkpoint_key='target_encoder' -): - logger.info(f'Loading pretrained model from {pretrained}') - checkpoint = torch.load(pretrained, map_location='cpu') +def load_pretrained(encoder, pretrained, checkpoint_key="target_encoder"): + logger.info(f"Loading pretrained model from {pretrained}") + checkpoint = torch.load(pretrained, map_location="cpu") try: pretrained_dict = checkpoint[checkpoint_key] except Exception: - pretrained_dict = checkpoint['encoder'] + pretrained_dict = checkpoint["encoder"] - pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()} - pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()} + pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()} + pretrained_dict = { + k.replace("backbone.", ""): v for k, v in pretrained_dict.items() + } for k, v in encoder.state_dict().items(): if k not in pretrained_dict: logger.info(f'key "{k}" could not be found in loaded state dict') elif pretrained_dict[k].shape != v.shape: - logger.info(f'key "{k}" is of different shape in model and loaded state dict') + logger.info( + f'key "{k}" is of different shape in model and loaded state dict' + ) pretrained_dict[k] = v msg = encoder.load_state_dict(pretrained_dict, strict=False) print(encoder) - logger.info(f'loaded pretrained model with msg: {msg}') - logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}') + logger.info(f"loaded pretrained model with msg: {msg}") + logger.info( + f'loaded pretrained encoder from epoch: 
{checkpoint["epoch"]}\n path: {pretrained}' + ) del checkpoint return encoder @@ -385,28 +390,31 @@ def make_dataloader( rank, resolution=224, training=False, - subset_file=None + subset_file=None, ): - normalization = ((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalization = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) if training: - logger.info('implementing auto-agument strategy') + logger.info("implementing auto-agument strategy") transform = timm_make_transforms( input_size=resolution, is_training=training, - auto_augment='original', - interpolation='bicubic', + auto_augment="original", + interpolation="bicubic", re_prob=0.25, - re_mode='pixel', + re_mode="pixel", re_count=1, mean=normalization[0], - std=normalization[1]) + std=normalization[1], + ) else: - transform = transforms.Compose([ - transforms.Resize(size=int(resolution * 256/224)), - transforms.CenterCrop(size=resolution), - transforms.ToTensor(), - transforms.Normalize(normalization[0], normalization[1])]) + transform = transforms.Compose( + [ + transforms.Resize(size=int(resolution * 256 / 224)), + transforms.CenterCrop(size=resolution), + transforms.ToTensor(), + transforms.Normalize(normalization[0], normalization[1]), + ] + ) data_loader, _ = init_data( data=dataset_name, @@ -419,7 +427,8 @@ def make_dataloader( training=training, copy_data=False, drop_last=False, - subset_file=subset_file) + subset_file=subset_file, + ) return data_loader @@ -436,7 +445,7 @@ def init_model( use_SiLU=False, tight_SiLU=True, uniform_power=False, - checkpoint_key='target_encoder' + checkpoint_key="target_encoder", ): encoder = vit.__dict__[model_name]( img_size=crop_size, @@ -449,15 +458,18 @@ def init_model( tight_SiLU=tight_SiLU, ) if frames_per_clip > 1: + def forward_prehook(module, input): input = input[0] # [B, C, H, W] input = input.unsqueeze(2).repeat(1, 1, frames_per_clip, 1, 1) - return (input) + return input encoder.register_forward_pre_hook(forward_prehook) encoder.to(device) - encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key) + encoder = load_pretrained( + encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key + ) return encoder @@ -471,33 +483,42 @@ def init_opt( wd=1e-6, final_wd=1e-6, final_lr=0.0, - use_bfloat16=False + use_bfloat16=False, ): param_groups = [ { - 'params': (p for n, p in classifier.named_parameters() - if ('bias' not in n) and (len(p.shape) != 1)) - }, { - 'params': (p for n, p in classifier.named_parameters() - if ('bias' in n) or (len(p.shape) == 1)), - 'WD_exclude': True, - 'weight_decay': 0 - } + "params": ( + p + for n, p in classifier.named_parameters() + if ("bias" not in n) and (len(p.shape) != 1) + ) + }, + { + "params": ( + p + for n, p in classifier.named_parameters() + if ("bias" in n) or (len(p.shape) == 1) + ), + "WD_exclude": True, + "weight_decay": 0, + }, ] - logger.info('Using AdamW') + logger.info("Using AdamW") optimizer = torch.optim.AdamW(param_groups) scheduler = WarmupCosineSchedule( optimizer, - warmup_steps=int(warmup*iterations_per_epoch), + warmup_steps=int(warmup * iterations_per_epoch), start_lr=start_lr, ref_lr=ref_lr, final_lr=final_lr, - T_max=int(num_epochs*iterations_per_epoch)) + T_max=int(num_epochs * iterations_per_epoch), + ) wd_scheduler = CosineWDSchedule( optimizer, ref_wd=wd, final_wd=final_wd, - T_max=int(num_epochs*iterations_per_epoch)) + T_max=int(num_epochs * iterations_per_epoch), + ) scaler = torch.cuda.amp.GradScaler() if use_bfloat16 else None return optimizer, scaler, scheduler, 
wd_scheduler diff --git a/evals/main.py b/evals/main.py index c614edb8..7a4b078e 100644 --- a/evals/main.py +++ b/evals/main.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -18,19 +18,24 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--fname', type=str, - help='name of config file to load', - default='configs.yaml') + "--fname", type=str, help="name of config file to load", default="configs.yaml" +) parser.add_argument( - '--devices', type=str, nargs='+', default=['cuda:0'], - help='which devices to use on local machine') + "--devices", + type=str, + nargs="+", + default=["cuda:0"], + help="which devices to use on local machine", +) def process_main(rank, fname, world_size, devices): import os - os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) + + os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1]) import logging + logging.basicConfig() logger = logging.getLogger() if rank == 0: @@ -38,30 +43,29 @@ def process_main(rank, fname, world_size, devices): else: logger.setLevel(logging.ERROR) - logger.info(f'called-params {fname}') + logger.info(f"called-params {fname}") # Load config params = None - with open(fname, 'r') as y_file: + with open(fname, "r") as y_file: params = yaml.load(y_file, Loader=yaml.FullLoader) - logger.info('loaded params...') + logger.info("loaded params...") pp = pprint.PrettyPrinter(indent=4) pp.pprint(params) # Init distributed (access to comm between GPUS on same machine) world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) - logger.info(f'Running... (rank: {rank}/{world_size})') + logger.info(f"Running... (rank: {rank}/{world_size})") # Launch the eval with loaded config - eval_main(params['eval_name'], args_eval=params) + eval_main(params["eval_name"], args_eval=params) -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() num_gpus = len(args.devices) - mp.set_start_method('spawn') + mp.set_start_method("spawn") for rank in range(num_gpus): mp.Process( - target=process_main, - args=(rank, args.fname, num_gpus, args.devices) + target=process_main, args=(rank, args.fname, num_gpus, args.devices) ).start() diff --git a/evals/main_distributed.py b/evals/main_distributed.py index 1f332a0b..29459223 100644 --- a/evals/main_distributed.py +++ b/evals/main_distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
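For reference, a minimal sketch of the per-device launch pattern used in the reformatted evals/main.py above: each spawned worker pins itself to a single GPU by setting CUDA_VISIBLE_DEVICES before any CUDA-aware import. The two-device list below is an assumption for illustration.

import multiprocessing as mp


def worker(rank, devices):
    import os

    # Restrict this process to its own GPU *before* importing torch (or any other
    # CUDA-aware library), so the worker only ever sees one physical device.
    os.environ["CUDA_VISIBLE_DEVICES"] = devices[rank].split(":")[-1]
    # import torch  # CUDA-aware imports happen only after the env var is set
    print(f"rank {rank} uses physical device {os.environ['CUDA_VISIBLE_DEVICES']}")


if __name__ == "__main__":
    devices = ["cuda:0", "cuda:1"]  # assumption: two local GPUs
    mp.set_start_method("spawn")
    for rank in range(len(devices)):
        mp.Process(target=worker, args=(rank, devices)).start()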
# # This source code is licensed under the license found in the @@ -22,32 +22,33 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--folder', type=str, - help='location to save submitit logs', - default='/fsx-jepa/massran/submitit/') + "--folder", + type=str, + help="location to save submitit logs", + default="/fsx-jepa/massran/submitit/", +) parser.add_argument( - '--exclude', type=str, - help='nodes to exclude from training', - default=None) + "--exclude", type=str, help="nodes to exclude from training", default=None +) parser.add_argument( - '--batch-launch', action='store_true', - help='whether fname points to a file to batch-lauch several config files') + "--batch-launch", + action="store_true", + help="whether fname points to a file to batch-lauch several config files", +) parser.add_argument( - '--fname', type=str, - help='yaml file containing config file names to launch', - default='configs.yaml') -parser.add_argument( - '--partition', type=str, - help='cluster partition to submit jobs on') -parser.add_argument( - '--time', type=int, default=4300, - help='time in minutes to run job') + "--fname", + type=str, + help="yaml file containing config file names to launch", + default="configs.yaml", +) +parser.add_argument("--partition", type=str, help="cluster partition to submit jobs on") +parser.add_argument("--time", type=int, default=4300, help="time in minutes to run job") class Trainer: def __init__(self, args_eval=None, resume_preempt=None): - self.eval_name = args_eval['eval_name'] + self.eval_name = args_eval["eval_name"] self.args_eval = args_eval self.resume_preempt = resume_preempt @@ -56,47 +57,47 @@ def __call__(self): args_eval = self.args_eval resume_preempt = self.resume_preempt - logger.info('loaded eval params...') + logger.info("loaded eval params...") pp = pprint.PrettyPrinter(indent=4) pp.pprint(args_eval) - eval_main( - eval_name, - args_eval=args_eval, - resume_preempt=resume_preempt) + eval_main(eval_name, args_eval=args_eval, resume_preempt=resume_preempt) def checkpoint(self): fb_trainer = Trainer(self.args_eval, True) - return submitit.helpers.DelayedSubmission(fb_trainer,) + return submitit.helpers.DelayedSubmission( + fb_trainer, + ) def launch_evals_with_parsed_args( args_for_evals, submitit_folder, - partition='learnlab,learnfair', + partition="learnlab,learnfair", timeout=4300, nodes=1, tasks_per_node=1, delay_seconds=10, - exclude_nodes=None + exclude_nodes=None, ): if not isinstance(args_for_evals, list): - logger.info(f'Passed in eval-args of type {type(args_for_evals)}') + logger.info(f"Passed in eval-args of type {type(args_for_evals)}") args_for_evals = [args_for_evals] time.sleep(delay_seconds) - logger.info('Launching evaluations in separate jobs...') + logger.info("Launching evaluations in separate jobs...") executor = submitit.AutoExecutor( - folder=os.path.join(submitit_folder, 'job_%j'), - slurm_max_num_timeout=20) + folder=os.path.join(submitit_folder, "job_%j"), slurm_max_num_timeout=20 + ) executor.update_parameters( slurm_partition=partition, - slurm_mem_per_gpu='55G', + slurm_mem_per_gpu="55G", timeout_min=timeout, nodes=nodes, tasks_per_node=tasks_per_node, cpus_per_task=12, - gpus_per_node=tasks_per_node) + gpus_per_node=tasks_per_node, + ) if exclude_nodes is not None: executor.update_parameters(slurm_exclude=exclude_nodes) @@ -105,12 +106,14 @@ def launch_evals_with_parsed_args( with executor.batch(): for ae in args_for_evals: fb_trainer = Trainer(ae) - job = executor.submit(fb_trainer,) + job = executor.submit( + 
fb_trainer, + ) trainers.append(fb_trainer) jobs.append(job) for job in jobs: - logger.info(f'Launched eval job with id {job.job_id}') + logger.info(f"Launched eval job with id {job.job_id}") def launch_evals(): @@ -124,7 +127,7 @@ def launch_evals(): # -- config, but actually specifies a list of other config files # -- to run in a slurm job array if args.batch_launch: - with open(args.fname, 'r') as y_file: + with open(args.fname, "r") as y_file: config_fnames = yaml.load(y_file, Loader=yaml.FullLoader) # ---------------------------------------------------------------------- # @@ -134,13 +137,13 @@ def launch_evals(): nodes, tasks_per_node = None, None configs = [] for f in config_fnames: - with open(f, 'r') as y_file: + with open(f, "r") as y_file: _params = yaml.load(y_file, Loader=yaml.FullLoader) - nodes = int(_params.get('nodes')) - tasks_per_node = int(_params.get('tasks_per_node')) + nodes = int(_params.get("nodes")) + tasks_per_node = int(_params.get("tasks_per_node")) configs += [_params] - logger.info(f'Loaded {len(configs)} config files') - logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}') + logger.info(f"Loaded {len(configs)} config files") + logger.info(f"Running all jobs with {nodes=} / {tasks_per_node=}") # ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- # @@ -153,10 +156,11 @@ def launch_evals(): timeout=args.time, nodes=nodes, tasks_per_node=tasks_per_node, - exclude_nodes=args.exclude) + exclude_nodes=args.exclude, + ) # ---------------------------------------------------------------------- # -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() launch_evals() diff --git a/evals/scaffold.py b/evals/scaffold.py index c816b874..c87b3ee8 100644 --- a/evals/scaffold.py +++ b/evals/scaffold.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -13,12 +13,8 @@ logger = logging.getLogger() -def main( - eval_name, - args_eval, - resume_preempt=False -): - logger.info(f'Running evaluation: {eval_name}') - return importlib.import_module(f'evals.{eval_name}.eval').main( - args_eval=args_eval, - resume_preempt=resume_preempt) +def main(eval_name, args_eval, resume_preempt=False): + logger.info(f"Running evaluation: {eval_name}") + return importlib.import_module(f"evals.{eval_name}.eval").main( + args_eval=args_eval, resume_preempt=resume_preempt + ) diff --git a/evals/video_classification_frozen/eval.py b/evals/video_classification_frozen/eval.py index f81f526d..c1e4164f 100644 --- a/evals/video_classification_frozen/eval.py +++ b/evals/video_classification_frozen/eval.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
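For reference, the reformatted evals/scaffold.py above dispatches an evaluation purely by name: each evaluation only needs to live in evals/<eval_name>/eval.py and expose main(args_eval, resume_preempt). A minimal sketch of that dispatch, using the eval name from this PR's frozen video classification eval:

import importlib


def run_eval(eval_name, args_eval, resume_preempt=False):
    # For eval_name == "video_classification_frozen" this resolves to
    # evals/video_classification_frozen/eval.py and calls its main().
    module = importlib.import_module(f"evals.{eval_name}.eval")
    return module.main(args_eval=args_eval, resume_preempt=resume_preempt)

Any new evaluation dropped under evals/ with that layout is picked up without touching the scaffold itself.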
# # This source code is licensed under the license found in the @@ -13,12 +13,13 @@ # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE # -- TO EACH PROCESS - os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] except Exception: pass import logging import pprint +import traceback import numpy as np @@ -33,23 +34,17 @@ from src.datasets.data_manager import ( init_data, ) -from src.utils.distributed import ( - init_distributed, - AllReduce -) +from src.utils.distributed import init_distributed, AllReduce from src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule, ) -from src.utils.logging import ( - AverageMeter, - CSVLogger -) +from src.utils.logging import AverageMeter, CSVLogger from evals.video_classification_frozen.utils import ( make_transforms, ClipAggregation, - FrameAggregation + FrameAggregation, ) logging.basicConfig() @@ -71,82 +66,81 @@ def main(args_eval, resume_preempt=False): # ----------------------------------------------------------------------- # # -- PRETRAIN - args_pretrain = args_eval.get('pretrain') - checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder') - model_name = args_pretrain.get('model_name', None) - patch_size = args_pretrain.get('patch_size', None) - pretrain_folder = args_pretrain.get('folder', None) - ckp_fname = args_pretrain.get('checkpoint', None) - tag = args_pretrain.get('write_tag', None) - use_sdpa = args_pretrain.get('use_sdpa', True) - use_SiLU = args_pretrain.get('use_silu', False) - tight_SiLU = args_pretrain.get('tight_silu', True) - uniform_power = args_pretrain.get('uniform_power', False) + args_pretrain = args_eval.get("pretrain") + checkpoint_key = args_pretrain.get("checkpoint_key", "target_encoder") + model_name = args_pretrain.get("model_name", None) + patch_size = args_pretrain.get("patch_size", None) + pretrain_folder = args_pretrain.get("folder", None) + ckp_fname = args_pretrain.get("checkpoint", None) + tag = args_pretrain.get("write_tag", None) + use_sdpa = args_pretrain.get("use_sdpa", True) + use_SiLU = args_pretrain.get("use_silu", False) + tight_SiLU = args_pretrain.get("tight_silu", True) + uniform_power = args_pretrain.get("uniform_power", False) pretrained_path = os.path.join(pretrain_folder, ckp_fname) # Optional [for Video model]: - tubelet_size = args_pretrain.get('tubelet_size', 2) - pretrain_frames_per_clip = args_pretrain.get('frames_per_clip', 1) + tubelet_size = args_pretrain.get("tubelet_size", 2) + pretrain_frames_per_clip = args_pretrain.get("frames_per_clip", 1) # -- DATA - args_data = args_eval.get('data') - train_data_path = [args_data.get('dataset_train')] - val_data_path = [args_data.get('dataset_val')] - dataset_type = args_data.get('dataset_type', 'VideoDataset') - num_classes = args_data.get('num_classes') - eval_num_segments = args_data.get('num_segments', 1) - eval_frames_per_clip = args_data.get('frames_per_clip', 16) - eval_frame_step = args_pretrain.get('frame_step', 4) - eval_duration = args_pretrain.get('clip_duration', None) - eval_num_views_per_segment = args_data.get('num_views_per_segment', 1) + args_data = args_eval.get("data") + train_data_path = [args_data.get("dataset_train")] + val_data_path = [args_data.get("dataset_val")] + dataset_type = args_data.get("dataset_type", "VideoDataset") + num_classes = args_data.get("num_classes") + eval_num_segments = args_data.get("num_segments", 1) + eval_frames_per_clip = 
args_data.get("frames_per_clip", 16) + eval_frame_step = args_pretrain.get("frame_step", 4) + eval_duration = args_pretrain.get("clip_duration", None) + eval_num_views_per_segment = args_data.get("num_views_per_segment", 1) # -- OPTIMIZATION - args_opt = args_eval.get('optimization') - resolution = args_opt.get('resolution', 224) - batch_size = args_opt.get('batch_size') - attend_across_segments = args_opt.get('attend_across_segments', False) - num_epochs = args_opt.get('num_epochs') - wd = args_opt.get('weight_decay') - start_lr = args_opt.get('start_lr') - lr = args_opt.get('lr') - final_lr = args_opt.get('final_lr') - warmup = args_opt.get('warmup') - use_bfloat16 = args_opt.get('use_bfloat16') + args_opt = args_eval.get("optimization") + resolution = args_opt.get("resolution", 224) + batch_size = args_opt.get("batch_size") + attend_across_segments = args_opt.get("attend_across_segments", False) + num_epochs = args_opt.get("num_epochs") + wd = args_opt.get("weight_decay") + start_lr = args_opt.get("start_lr") + lr = args_opt.get("lr") + final_lr = args_opt.get("final_lr") + warmup = args_opt.get("warmup") + use_bfloat16 = args_opt.get("use_bfloat16") # -- EXPERIMENT-ID/TAG (optional) - resume_checkpoint = args_eval.get('resume_checkpoint', False) or resume_preempt - eval_tag = args_eval.get('tag', None) + resume_checkpoint = args_eval.get("resume_checkpoint", False) or resume_preempt + eval_tag = args_eval.get("tag", None) # ----------------------------------------------------------------------- # try: - mp.set_start_method('spawn') + mp.set_start_method("spawn") except Exception: pass if not torch.cuda.is_available(): - device = torch.device('cpu') + device = torch.device("cpu") else: - device = torch.device('cuda:0') + device = torch.device("cuda:0") torch.cuda.set_device(device) world_size, rank = init_distributed() - logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + logger.info(f"Initialized (rank/world-size) {rank}/{world_size}") # -- log/checkpointing paths - folder = os.path.join(pretrain_folder, 'video_classification_frozen/') + folder = os.path.join(pretrain_folder, "video_classification_frozen/") if eval_tag is not None: folder = os.path.join(folder, eval_tag) if not os.path.exists(folder): os.makedirs(folder, exist_ok=True) - log_file = os.path.join(folder, f'{tag}_r{rank}.csv') - latest_path = os.path.join(folder, f'{tag}-latest.pth.tar') + log_file = os.path.join(folder, f"{tag}_r{rank}.csv") + latest_path = os.path.join(folder, f"{tag}-latest.pth.tar") # -- make csv_logger if rank == 0: - csv_logger = CSVLogger(log_file, - ('%d', 'epoch'), - ('%.5f', 'loss'), - ('%.5f', 'acc')) + csv_logger = CSVLogger( + log_file, ("%d", "epoch"), ("%.5f", "loss"), ("%.5f", "acc") + ) # Initialize model @@ -163,7 +157,8 @@ def main(args_eval, resume_preempt=False): checkpoint_key=checkpoint_key, use_SiLU=use_SiLU, tight_SiLU=tight_SiLU, - use_sdpa=use_sdpa) + use_sdpa=use_sdpa, + ) if pretrain_frames_per_clip == 1: # Process each frame independently and aggregate encoder = FrameAggregation(encoder).to(device) @@ -172,7 +167,7 @@ def main(args_eval, resume_preempt=False): encoder = ClipAggregation( encoder, tubelet_size=tubelet_size, - attend_across_segments=attend_across_segments + attend_across_segments=attend_across_segments, ).to(device) encoder.eval() for p in encoder.parameters(): @@ -199,7 +194,8 @@ def main(args_eval, resume_preempt=False): batch_size=batch_size, world_size=world_size, rank=rank, - training=True) + training=True, + ) val_loader = make_dataloader( 
dataset_type=dataset_type, root_path=val_data_path, @@ -213,9 +209,10 @@ def main(args_eval, resume_preempt=False): batch_size=batch_size, world_size=world_size, rank=rank, - training=False) + training=False, + ) ipe = len(train_loader) - logger.info(f'Dataloader created... iterations per epoch: {ipe}') + logger.info(f"Dataloader created... iterations per epoch: {ipe}") # -- optimizer and scheduler optimizer, scaler, scheduler, wd_scheduler = init_opt( @@ -227,7 +224,8 @@ def main(args_eval, resume_preempt=False): iterations_per_epoch=ipe, warmup=warmup, num_epochs=num_epochs, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) classifier = DistributedDataParallel(classifier, static_graph=True) # -- load training checkpoint @@ -238,27 +236,28 @@ def main(args_eval, resume_preempt=False): r_path=latest_path, classifier=classifier, opt=optimizer, - scaler=scaler) - for _ in range(start_epoch*ipe): + scaler=scaler, + ) + for _ in range(start_epoch * ipe): scheduler.step() wd_scheduler.step() def save_checkpoint(epoch): save_dict = { - 'classifier': classifier.state_dict(), - 'opt': optimizer.state_dict(), - 'scaler': None if scaler is None else scaler.state_dict(), - 'epoch': epoch, - 'batch_size': batch_size, - 'world_size': world_size, - 'lr': lr + "classifier": classifier.state_dict(), + "opt": optimizer.state_dict(), + "scaler": None if scaler is None else scaler.state_dict(), + "epoch": epoch, + "batch_size": batch_size, + "world_size": world_size, + "lr": lr, } if rank == 0: torch.save(save_dict, latest_path) # TRAIN LOOP for epoch in range(start_epoch, num_epochs): - logger.info('Epoch %d' % (epoch + 1)) + logger.info("Epoch %d" % (epoch + 1)) train_acc = run_one_epoch( device=device, training=True, @@ -272,7 +271,8 @@ def save_checkpoint(epoch): scheduler=scheduler, wd_scheduler=wd_scheduler, data_loader=train_loader, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) val_acc = run_one_epoch( device=device, @@ -287,9 +287,12 @@ def save_checkpoint(epoch): scheduler=scheduler, wd_scheduler=wd_scheduler, data_loader=val_loader, - use_bfloat16=use_bfloat16) + use_bfloat16=use_bfloat16, + ) - logger.info('[%5d] train: %.3f%% test: %.3f%%' % (epoch + 1, train_acc, val_acc)) + logger.info( + "[%5d] train: %.3f%% test: %.3f%%" % (epoch + 1, train_acc, val_acc) + ) if rank == 0: csv_logger.log(epoch + 1, train_acc, val_acc) save_checkpoint(epoch + 1) @@ -324,7 +327,9 @@ def run_one_epoch( # Load data and put on GPU clips = [ - [dij.to(device, non_blocking=True) for dij in di] # iterate over spatial views of clip + [ + dij.to(device, non_blocking=True) for dij in di + ] # iterate over spatial views of clip for di in data[0] # iterate over temporal index of clip ] clip_indices = [d.to(device, non_blocking=True) for d in data[2]] @@ -349,13 +354,21 @@ def run_one_epoch( if attend_across_segments: loss = sum([criterion(o, labels) for o in outputs]) / len(outputs) else: - loss = sum([sum([criterion(ost, labels) for ost in os]) for os in outputs]) / len(outputs) / len(outputs[0]) + loss = ( + sum([sum([criterion(ost, labels) for ost in os]) for os in outputs]) + / len(outputs) + / len(outputs[0]) + ) with torch.no_grad(): if attend_across_segments: outputs = sum([F.softmax(o, dim=1) for o in outputs]) / len(outputs) else: - outputs = sum([sum([F.softmax(ost, dim=1) for ost in os]) for os in outputs]) / len(outputs) / len(outputs[0]) - top1_acc = 100. 
* outputs.max(dim=1).indices.eq(labels).sum() / batch_size + outputs = ( + sum([sum([F.softmax(ost, dim=1) for ost in os]) for os in outputs]) + / len(outputs) + / len(outputs[0]) + ) + top1_acc = 100.0 * outputs.max(dim=1).indices.eq(labels).sum() / batch_size top1_acc = float(AllReduce.apply(top1_acc)) top1_meter.update(top1_acc) @@ -373,68 +386,70 @@ def run_one_epoch( optimizer.zero_grad() if itr % 20 == 0: - logger.info('[%5d] %.3f%% (loss: %.3f) [mem: %.2e]' - % (itr, top1_meter.avg, loss, - torch.cuda.max_memory_allocated() / 1024.**2)) + logger.info( + "[%5d] %.3f%% (loss: %.3f) [mem: %.2e]" + % ( + itr, + top1_meter.avg, + loss, + torch.cuda.max_memory_allocated() / 1024.0**2, + ) + ) return top1_meter.avg -def load_checkpoint( - device, - r_path, - classifier, - opt, - scaler -): +def load_checkpoint(device, r_path, classifier, opt, scaler): try: - checkpoint = torch.load(r_path, map_location=torch.device('cpu')) - epoch = checkpoint['epoch'] + checkpoint = torch.load(r_path, map_location=torch.device("cpu")) + epoch = checkpoint["epoch"] # -- loading encoder - pretrained_dict = checkpoint['classifier'] + pretrained_dict = checkpoint["classifier"] msg = classifier.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained classifier from epoch {epoch} with msg: {msg}') + logger.info(f"loaded pretrained classifier from epoch {epoch} with msg: {msg}") # -- loading optimizer - opt.load_state_dict(checkpoint['opt']) + opt.load_state_dict(checkpoint["opt"]) if scaler is not None: - scaler.load_state_dict(checkpoint['scaler']) - logger.info(f'loaded optimizers from epoch {epoch}') - logger.info(f'read-path: {r_path}') + scaler.load_state_dict(checkpoint["scaler"]) + logger.info(f"loaded optimizers from epoch {epoch}") + logger.info(f"read-path: {r_path}") del checkpoint except Exception as e: - logger.info(f'Encountered exception when loading checkpoint {e}') + logger.info(f"Encountered exception when loading checkpoint {traceback.format_exc}") epoch = 0 return classifier, opt, scaler, epoch -def load_pretrained( - encoder, - pretrained, - checkpoint_key='target_encoder' -): - logger.info(f'Loading pretrained model from {pretrained}') - checkpoint = torch.load(pretrained, map_location='cpu') +def load_pretrained(encoder, pretrained, checkpoint_key="target_encoder"): + logger.info(f"Loading pretrained model from {pretrained}") + checkpoint = torch.load(pretrained, map_location="cpu") try: pretrained_dict = checkpoint[checkpoint_key] except Exception: - pretrained_dict = checkpoint['encoder'] + pretrained_dict = checkpoint["encoder"] - pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()} - pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()} + pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()} + pretrained_dict = { + k.replace("backbone.", ""): v for k, v in pretrained_dict.items() + } for k, v in encoder.state_dict().items(): if k not in pretrained_dict: logger.info(f'key "{k}" could not be found in loaded state dict') elif pretrained_dict[k].shape != v.shape: - logger.info(f'key "{k}" is of different shape in model and loaded state dict') + logger.info( + f'key "{k}" is of different shape in model and loaded state dict' + ) pretrained_dict[k] = v msg = encoder.load_state_dict(pretrained_dict, strict=False) print(encoder) - logger.info(f'loaded pretrained model with msg: {msg}') - logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}') + 
logger.info(f"loaded pretrained model with msg: {msg}") + logger.info( + f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}' + ) del checkpoint return encoder @@ -444,7 +459,7 @@ def make_dataloader( batch_size, world_size, rank, - dataset_type='VideoDataset', + dataset_type="VideoDataset", resolution=224, frames_per_clip=16, frame_step=4, @@ -454,14 +469,14 @@ def make_dataloader( allow_segment_overlap=True, training=False, num_workers=12, - subset_file=None + subset_file=None, ): # Make Video Transforms transform = make_transforms( training=training, num_views_per_clip=num_views_per_segment, random_horizontal_flip=False, - random_resize_aspect_ratio=(0.75, 4/3), + random_resize_aspect_ratio=(0.75, 4 / 3), random_resize_scale=(0.08, 1.0), reprob=0.25, auto_augment=True, @@ -484,7 +499,8 @@ def make_dataloader( num_workers=num_workers, copy_data=False, drop_last=False, - subset_file=subset_file) + subset_file=subset_file, + ) return data_loader @@ -501,7 +517,7 @@ def init_model( use_SiLU=False, tight_SiLU=True, uniform_power=False, - checkpoint_key='target_encoder' + checkpoint_key="target_encoder", ): encoder = vit.__dict__[model_name]( img_size=crop_size, @@ -515,7 +531,9 @@ def init_model( ) encoder.to(device) - encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key) + encoder = load_pretrained( + encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key + ) return encoder @@ -529,33 +547,42 @@ def init_opt( wd=1e-6, final_wd=1e-6, final_lr=0.0, - use_bfloat16=False + use_bfloat16=False, ): param_groups = [ { - 'params': (p for n, p in classifier.named_parameters() - if ('bias' not in n) and (len(p.shape) != 1)) - }, { - 'params': (p for n, p in classifier.named_parameters() - if ('bias' in n) or (len(p.shape) == 1)), - 'WD_exclude': True, - 'weight_decay': 0 - } + "params": ( + p + for n, p in classifier.named_parameters() + if ("bias" not in n) and (len(p.shape) != 1) + ) + }, + { + "params": ( + p + for n, p in classifier.named_parameters() + if ("bias" in n) or (len(p.shape) == 1) + ), + "WD_exclude": True, + "weight_decay": 0, + }, ] - logger.info('Using AdamW') + logger.info("Using AdamW") optimizer = torch.optim.AdamW(param_groups) scheduler = WarmupCosineSchedule( optimizer, - warmup_steps=int(warmup*iterations_per_epoch), + warmup_steps=int(warmup * iterations_per_epoch), start_lr=start_lr, ref_lr=ref_lr, final_lr=final_lr, - T_max=int(num_epochs*iterations_per_epoch)) + T_max=int(num_epochs * iterations_per_epoch), + ) wd_scheduler = CosineWDSchedule( optimizer, ref_wd=wd, final_wd=final_wd, - T_max=int(num_epochs*iterations_per_epoch)) + T_max=int(num_epochs * iterations_per_epoch), + ) scaler = torch.cuda.amp.GradScaler() if use_bfloat16 else None return optimizer, scaler, scheduler, wd_scheduler diff --git a/evals/video_classification_frozen/utils.py b/evals/video_classification_frozen/utils.py index 450f799a..c1975dde 100644 --- a/evals/video_classification_frozen/utils.py +++ b/evals/video_classification_frozen/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -26,11 +26,7 @@ class FrameAggregation(nn.Module): """ def __init__( - self, - model, - max_frames=10000, - use_pos_embed=False, - attend_across_segments=False + self, model, max_frames=10000, use_pos_embed=False, attend_across_segments=False ): super().__init__() self.model = model @@ -41,8 +37,8 @@ def __init__( self.pos_embed = None if use_pos_embed: self.pos_embed = nn.Parameter( - torch.zeros(1, max_frames, embed_dim), - requires_grad=False) + torch.zeros(1, max_frames, embed_dim), requires_grad=False + ) sincos = get_1d_sincos_pos_embed(embed_dim, max_frames) self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) @@ -59,7 +55,7 @@ def forward(self, x, clip_indices=None): B, C, T, H, W = x.size() # Put each frame along the batch dimension - x = x.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W) + x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W) outputs = self.model(x) _, N, D = outputs.size() @@ -69,13 +65,19 @@ def forward(self, x, clip_indices=None): B = B // num_views_per_clip all_outputs = [] for i in range(num_views_per_clip): - o = outputs[i*B:(i+1)*B] + o = outputs[i * B : (i + 1) * B] # Compute positional embedding if (self.pos_embed is not None) and (clip_indices is not None): pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D] - pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D])) - pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension - pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D] + pos_embed = apply_masks( + pos_embed, clip_indices, concat=False + ) # list(Tensor([B, T, D])) + pos_embed = torch.cat( + pos_embed, dim=1 + ) # concatenate along temporal dimension + pos_embed = pos_embed.unsqueeze(2).repeat( + 1, 1, N, 1 + ) # [B, T*num_clips, N, D] pos_embed = pos_embed.flatten(1, 2) o += pos_embed all_outputs += [o] @@ -94,7 +96,7 @@ def __init__( tubelet_size=2, max_frames=10000, use_pos_embed=False, - attend_across_segments=False + attend_across_segments=False, ): super().__init__() self.model = model @@ -107,8 +109,8 @@ def __init__( if use_pos_embed: max_T = max_frames // tubelet_size self.pos_embed = nn.Parameter( - torch.zeros(1, max_T, embed_dim), - requires_grad=False) + torch.zeros(1, max_T, embed_dim), requires_grad=False + ) sincos = get_1d_sincos_pos_embed(embed_dim, max_T) self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) @@ -131,9 +133,9 @@ def forward(self, x, clip_indices=None): eff_B = B * num_views_per_clip all_outputs = [[] for _ in range(num_views_per_clip)] for i in range(num_clips): - o = outputs[i*eff_B:(i+1)*eff_B] + o = outputs[i * eff_B : (i + 1) * eff_B] for j in range(num_views_per_clip): - all_outputs[j].append(o[j*B:(j+1)*B]) + all_outputs[j].append(o[j * B : (j + 1) * B]) if not self.attend_across_segments: return all_outputs @@ -146,11 +148,17 @@ def forward(self, x, clip_indices=None): # Compute positional embedding if (self.pos_embed is not None) and (clip_indices is not None): - clip_indices = [c[:, ::self.tubelet_size] for c in clip_indices] + clip_indices = [c[:, :: self.tubelet_size] for c in clip_indices] pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D] - pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D])) - pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension - pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D] + pos_embed = apply_masks( + 
pos_embed, clip_indices, concat=False + ) # list(Tensor([B, T, D])) + pos_embed = torch.cat( + pos_embed, dim=1 + ) # concatenate along temporal dimension + pos_embed = pos_embed.unsqueeze(2).repeat( + 1, 1, N, 1 + ) # [B, T*num_clips, N, D] pos_embed = pos_embed.flatten(1, 2) outputs += pos_embed @@ -162,19 +170,18 @@ def forward(self, x, clip_indices=None): def make_transforms( training=True, random_horizontal_flip=True, - random_resize_aspect_ratio=(3/4, 4/3), + random_resize_aspect_ratio=(3 / 4, 4 / 3), random_resize_scale=(0.3, 1.0), reprob=0.0, auto_augment=False, motion_shift=False, crop_size=224, num_views_per_clip=1, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): if not training and num_views_per_clip > 1: - print('Making EvalVideoTransform, multi-view') + print("Making EvalVideoTransform, multi-view") _frames_augmentation = EvalVideoTransform( num_views_per_clip=num_views_per_clip, short_side_size=crop_size, @@ -202,25 +209,26 @@ def __init__( self, training=True, random_horizontal_flip=True, - random_resize_aspect_ratio=(3/4, 4/3), + random_resize_aspect_ratio=(3 / 4, 4 / 3), random_resize_scale=(0.3, 1.0), reprob=0.0, auto_augment=False, motion_shift=False, crop_size=224, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): self.training = training short_side_size = int(crop_size * 256 / 224) - self.eval_transform = video_transforms.Compose([ - video_transforms.Resize(short_side_size, interpolation='bilinear'), - video_transforms.CenterCrop(size=(crop_size, crop_size)), - volume_transforms.ClipToTensor(), - video_transforms.Normalize(mean=normalize[0], std=normalize[1]) - ]) + self.eval_transform = video_transforms.Compose( + [ + video_transforms.Resize(short_side_size, interpolation="bilinear"), + video_transforms.CenterCrop(size=(crop_size, crop_size)), + volume_transforms.ClipToTensor(), + video_transforms.Normalize(mean=normalize[0], std=normalize[1]), + ] + ) self.random_horizontal_flip = random_horizontal_flip self.random_resize_aspect_ratio = random_resize_aspect_ratio @@ -232,20 +240,23 @@ def __init__( self.autoaug_transform = video_transforms.create_random_augment( input_size=(crop_size, crop_size), - auto_augment='rand-m7-n4-mstd0.5-inc1', - interpolation='bicubic', + auto_augment="rand-m7-n4-mstd0.5-inc1", + interpolation="bicubic", ) - self.spatial_transform = video_transforms.random_resized_crop_with_shift \ - if motion_shift else video_transforms.random_resized_crop + self.spatial_transform = ( + video_transforms.random_resized_crop_with_shift + if motion_shift + else video_transforms.random_resized_crop + ) self.reprob = reprob self.erase_transform = RandomErasing( reprob, - mode='pixel', + mode="pixel", max_count=1, num_splits=1, - device='cpu', + device="cpu", ) def __call__(self, buffer): @@ -289,16 +300,19 @@ def __init__( self, num_views_per_clip=1, short_side_size=224, - normalize=((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ): self.views_per_clip = num_views_per_clip self.short_side_size = short_side_size - self.spatial_resize = video_transforms.Resize(short_side_size, interpolation='bilinear') - self.to_tensor = video_transforms.Compose([ - volume_transforms.ClipToTensor(), - video_transforms.Normalize(mean=normalize[0], std=normalize[1]) - ]) + self.spatial_resize = video_transforms.Resize( + short_side_size, interpolation="bilinear" + ) + 
self.to_tensor = video_transforms.Compose( + [ + volume_transforms.ClipToTensor(), + video_transforms.Normalize(mean=normalize[0], std=normalize[1]), + ] + ) def __call__(self, buffer): @@ -312,11 +326,11 @@ def __call__(self, buffer): all_views = [] for i in range(num_views): - start = i*spatial_step + start = i * spatial_step if H > W: - view = buffer[:, start:start+side_len, :, :] + view = buffer[:, start : start + side_len, :, :] else: - view = buffer[:, :, start:start+side_len, :] + view = buffer[:, :, start : start + side_len, :] view = self.to_tensor(view) all_views.append(view) diff --git a/logs_and_checkpoints/.gitkeep b/logs_and_checkpoints/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/logs_and_checkpoints/jepa_r0.csv b/logs_and_checkpoints/jepa_r0.csv new file mode 100644 index 00000000..eb59245a --- /dev/null +++ b/logs_and_checkpoints/jepa_r0.csv @@ -0,0 +1,45 @@ +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) 
+epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) +epoch,itr,loss,loss-jepa,reg-loss,enc-grad-norm,pred-grad-norm,gpu-time(ms),wall-time(ms) diff --git a/logs_and_checkpoints/params-pretrain.yaml b/logs_and_checkpoints/params-pretrain.yaml new file mode 100644 index 00000000..fdc80bd0 --- /dev/null +++ b/logs_and_checkpoints/params-pretrain.yaml @@ -0,0 +1,88 @@ +app: vjepa +data: + batch_size: 10 + clip_duration: null + crop_size: 384 + dataset_type: egovehicle_imagedataset + datasets: + - /home/ncdev/Documents/darwin/data/raw/ + decode_one_clip: true + filter_short_videos: false + num_clips: 1 + num_frames: 16 + num_workers: 12 + patch_size: 16 + pin_mem: true + sampling_rate: 4 + tubelet_size: 2 +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 +logging: + folder: /home/ncdev/Documents/darwin/jepa/logs_and_checkpoints + write_tag: jepa +loss: + loss_exp: 1.0 + reg_coeff: 0.0 +mask: +- aspect_ratio: + - 0.75 + - 1.5 + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: false + read_checkpoint: null + seed: 234 + use_sdpa: true +model: + model_name: vit_huge + pred_depth: 12 + pred_embed_dim: 384 + uniform_power: true + use_mask_tokens: true + zero_init_mask_tokens: true +nodes: 30 +optimization: + clip_grad: 10.0 + ema: + - 0.998 + - 1.0 + epochs: 300 + final_lr: 1.0e-06 + final_weight_decay: 0.4 + ipe: 300 + ipe_scale: 1.25 + lr: 0.000625 + start_lr: 0.0002 + warmup: 40 + weight_decay: 0.04 +tasks_per_node: 8 diff --git a/setup.py b/setup.py index 82de1e0c..c987138b 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
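For reference, logs_and_checkpoints/params-pretrain.yaml above is the resolved pretraining config dumped at startup. A minimal sketch of loading it back and reading a few of the nested fields it records (values in the comments are taken from the file as committed in this PR):

import yaml

with open("logs_and_checkpoints/params-pretrain.yaml", "r") as y_file:
    params = yaml.load(y_file, Loader=yaml.FullLoader)

print(params["app"])                   # "vjepa"
print(params["data"]["crop_size"])     # 384
print(params["data"]["dataset_type"])  # "egovehicle_imagedataset"
print(params["model"]["model_name"])   # "vit_huge"
print(params["logging"]["folder"])     # where checkpoints and CSV logs are written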
@@ -9,6 +9,7 @@ VERSION = "0.0.1" + def get_requirements(): with open("./requirements.txt") as reqsf: reqs = reqsf.readlines() diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datasets/data_manager.py b/src/datasets/data_manager.py index cdb7ade4..0adff499 100644 --- a/src/datasets/data_manager.py +++ b/src/datasets/data_manager.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -16,7 +16,7 @@ def init_data( batch_size, transform=None, shared_transform=None, - data='ImageNet', + data="ImageNet", collator=None, pin_mem=True, num_workers=8, @@ -45,10 +45,13 @@ def init_data( log_dir=None, ): - if (data.lower() == 'imagenet') \ - or (data.lower() == 'inat21') \ - or (data.lower() == 'places205'): + if ( + (data.lower() == "imagenet") + or (data.lower() == "inat21") + or (data.lower() == "places205") + ): from src.datasets.image_dataset import make_imagedataset + dataset, data_loader, dist_sampler = make_imagedataset( transform=transform, batch_size=batch_size, @@ -63,10 +66,12 @@ def init_data( persistent_workers=persistent_workers, copy_data=copy_data, drop_last=drop_last, - subset_file=subset_file) + subset_file=subset_file, + ) - elif data.lower() == 'videodataset': + elif data.lower() == "videodataset": from src.datasets.video_dataset import make_videodataset + dataset, data_loader, dist_sampler = make_videodataset( data_paths=root_path, batch_size=batch_size, @@ -86,6 +91,21 @@ def init_data( world_size=world_size, rank=rank, drop_last=drop_last, - log_dir=log_dir) - + log_dir=log_dir, + ) + elif data.lower() == "egovehicle_imagedataset": + from src.datasets.image_dataset import make_egovehicle_imagedataset + + dataset, data_loader, dist_sampler = make_egovehicle_imagedataset( + data_dir=root_path[0], + batch_size=batch_size, + transform=transform, + shared_transform=shared_transform, + mask_collator=collator, + num_workers=num_workers, + world_size=world_size, + rank=rank, + pin_mem=pin_mem, + drop_last=drop_last, + ) return (data_loader, dist_sampler) diff --git a/src/datasets/image_dataset.py b/src/datasets/image_dataset.py index 84e9b082..fe6c7b55 100644 --- a/src/datasets/image_dataset.py +++ b/src/datasets/image_dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
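For reference, the new "egovehicle_imagedataset" branch added to init_data() in src/datasets/data_manager.py above routes the request to make_egovehicle_imagedataset(), forwarding root_path[0] as the data directory. A hedged usage sketch; keyword names follow the call sites visible in this diff, and the path and batch size come from the dumped params-pretrain.yaml:

from src.datasets.data_manager import init_data

data_loader, dist_sampler = init_data(
    data="egovehicle_imagedataset",
    root_path=["/home/ncdev/Documents/darwin/data/raw/"],  # root_path[0] becomes data_dir
    batch_size=10,
    transform=None,   # optional per-image transform
    collator=None,    # e.g. the pretraining mask collator
    num_workers=12,
    world_size=1,
    rank=0,
    pin_mem=True,
    drop_last=True,
)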
# # This source code is licensed under the license found in the @@ -6,11 +6,18 @@ # import os +import PIL +from collections import defaultdict + from logging import getLogger import torch import torchvision +from datetime import datetime +import numpy as np +from torch.utils.data import DataLoader, DistributedSampler, Sampler + _GLOBAL_SEED = 0 logger = getLogger() @@ -21,7 +28,7 @@ class ImageFolder(torchvision.datasets.ImageFolder): def __init__( self, root, - image_folder='imagenet_full_size/061417/', + image_folder="imagenet_full_size/061417/", transform=None, train=True, ): @@ -32,11 +39,11 @@ def __init__( :param train: whether to load train data (or validation) """ - suffix = 'train/' if train else 'val/' + suffix = "train/" if train else "val/" data_path = os.path.join(root, image_folder, suffix) - logger.info(f'data-path {data_path}') + logger.info(f"data-path {data_path}") super(ImageFolder, self).__init__(root=data_path, transform=transform) - logger.info('Initialized ImageFolder') + logger.info("Initialized ImageFolder") def make_imagedataset( @@ -53,18 +60,15 @@ def make_imagedataset( copy_data=False, drop_last=True, persistent_workers=False, - subset_file=None + subset_file=None, ): dataset = ImageFolder( - root=root_path, - image_folder=image_folder, - transform=transform, - train=training) - logger.info('ImageFolder dataset created') + root=root_path, image_folder=image_folder, transform=transform, train=training + ) + logger.info("ImageFolder dataset created") dist_sampler = torch.utils.data.distributed.DistributedSampler( - dataset=dataset, - num_replicas=world_size, - rank=rank) + dataset=dataset, num_replicas=world_size, rank=rank + ) data_loader = torch.utils.data.DataLoader( dataset, collate_fn=collator, @@ -73,7 +77,196 @@ def make_imagedataset( drop_last=drop_last, pin_memory=pin_mem, num_workers=num_workers, - persistent_workers=persistent_workers) - logger.info('ImageFolder unsupervised data loader created') + persistent_workers=persistent_workers, + ) + logger.info("ImageFolder unsupervised data loader created") + + return dataset, data_loader, dist_sampler + + +import os +import pandas as pd +import torch +from PIL import Image + + +class ImageDataset(torch.utils.data.Dataset): + def __init__(self, data_dir, transform=None, shared_transform=None): + self.data_dir = data_dir + self.transform = transform + self.shared_transform = shared_transform + + # Load data from drive folders + self.samples = [] + self.drive_data = {} + + try: + drive_folders = [f for f in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, f))] + for drive_folder in drive_folders: + drive_path = os.path.join(data_dir, drive_folder) + csv_file = os.path.join(drive_path, "drive_data.csv") + + if not os.path.exists(csv_file): + logger.warning(f"Skipping drive folder '{drive_folder}' due to missing drive_data.csv file.") + continue + + try: + drive_df = pd.read_csv(csv_file) + self.drive_data[drive_folder] = drive_df + drive_samples = [(os.path.join(drive_path, row['path_to_image']), row['maneuverID']) for _, row in drive_df.iterrows()] + self.samples.extend(drive_samples) + except (pd.errors.EmptyDataError, KeyError) as e: + logger.warning(f"Skipping drive folder '{drive_folder}' due to error: {str(e)}") + + if len(self.samples) == 0: + raise RuntimeError(f"No valid drive folders found in the dataset directory: {data_dir}") + + except OSError as e: + raise RuntimeError(f"Error accessing dataset directory: {data_dir}. 
Exception: {str(e)}") + + def __getitem__(self, index): + try: + image_path, maneuver_id = self.samples[index] + + # Load image + try: + image = Image.open(image_path).convert("RGB") # Convert to RGB here + except (IOError, PIL.UnidentifiedImageError) as e: + logger.warning(f"Error loading image: {image_path}. Exception: {str(e)}") + raise e + + # Apply transforms + if self.shared_transform is not None: + image = self.shared_transform(image) + if self.transform is not None: + image = self.transform(image) + + return image, maneuver_id + + except IndexError as e: + raise IndexError(f"Index {index} is out of bounds for the dataset.") + + def __len__(self): + if not self.samples: + raise RuntimeError("Dataset is empty. No valid samples found.") + return len(self.samples) + +class SequentialImageSampler(Sampler): + def __init__(self, image_dataset, num_replicas=None, rank=None): + super().__init__(image_dataset) + self.image_dataset = image_dataset + self.num_replicas = num_replicas + self.rank = rank + self.grouped_images = self.group_images_by_folder() + + def group_images_by_folder(self): + # Group image paths by folder, sorting by timestamp within each folder + grouped_images = {} + for folder_path, image_filename in self.image_dataset.samples: + grouped_images.setdefault(folder_path, []).append(image_filename) + for folder_path in grouped_images: + grouped_images[folder_path] = sorted(grouped_images[folder_path], key=self.image_dataset.extract_timestamp_from_filename) + return grouped_images + + def __iter__(self): + # Determine which folders this worker should handle + worker_folders = [ + folder + for i, folder in enumerate(sorted(self.grouped_images.keys())) + if i % self.num_replicas == self.rank + ] + + # Yield image indices in sequential order for each assigned folder + for folder_path in worker_folders: + for image_filename in self.grouped_images[folder_path]: + yield self.image_dataset.samples.index((folder_path, image_filename)) + + def __len__(self): + # Total number of samples across all workers + total_samples = sum(len(images) for images in self.grouped_images.values()) + # Number of samples for this worker + num_samples_per_worker = total_samples // self.num_replicas + # Add any remaining samples to the last worker + if self.rank == self.num_replicas - 1: + num_samples_per_worker += total_samples % self.num_replicas + return num_samples_per_worker + +def collate_fn(batch): + images, maneuvers = zip(*batch) + + # Stack images into a single tensor + images = torch.stack(images, dim=0) + + # Convert maneuvers to a tensor + maneuvers = torch.tensor([m for m in maneuvers]) + + return images, maneuvers + +class SequentialDriveSampler(Sampler): + def __init__(self, image_dataset): + self.image_dataset = image_dataset + self.drive_indices = self._get_drive_indices() + + def _get_drive_indices(self): + drive_indices = defaultdict(list) + for idx, (image_path, _) in enumerate(self.image_dataset.samples): + drive_folder = os.path.basename(os.path.dirname(image_path)) + drive_indices[drive_folder].append(idx) + return drive_indices + + def __iter__(self): + for drive_folder, indices in self.drive_indices.items(): + yield from indices + + def __len__(self): + return len(self.image_dataset) + +def make_egovehicle_imagedataset( + data_dir, + batch_size, + transform=None, + shared_transform=None, + mask_collator=None, + num_workers=10, + rank=0, + world_size=1, + pin_mem=True, + drop_last=True, +): + dataset = ImageDataset( + data_dir=data_dir, + transform=transform, + 
shared_transform=shared_transform, + ) + + logger.info("ImageDataset created") + + # sampler = SequentialDriveSampler(dataset) + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=True + ) + + # data_loader = DataLoader( + # dataset, + # batch_size=batch_size, + # sampler=dist_sampler, + # collate_fn=mask_collator, + # num_workers=num_workers, + # pin_memory=pin_mem, + # drop_last=True, + # ) + + data_loader = DataLoader( + dataset, + collate_fn=mask_collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=num_workers > 0, + ) + + logger.info("ImageDataset data loader created") return dataset, data_loader, dist_sampler diff --git a/src/datasets/utils/__init__.py b/src/datasets/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datasets/utils/video/__init__.py b/src/datasets/utils/video/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/datasets/utils/video/functional.py b/src/datasets/utils/video/functional.py index a91d15d2..3136cab8 100644 --- a/src/datasets/utils/video/functional.py +++ b/src/datasets/utils/video/functional.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -18,56 +18,54 @@ def _is_tensor_clip(clip): def crop_clip(clip, min_h, min_w, h, w): if isinstance(clip[0], np.ndarray): - cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip] elif isinstance(clip[0], PIL.Image.Image): - cropped = [ - img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip - ] + cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip] else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + ) return cropped -def resize_clip(clip, size, interpolation='bilinear'): +def resize_clip(clip, size, interpolation="bilinear"): if isinstance(clip[0], np.ndarray): if isinstance(size, numbers.Number): im_h, im_w, im_c = clip[0].shape # Min spatial dim already matches minimal size - if (im_w <= im_h and im_w == size) or (im_h <= im_w - and im_h == size): + if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): return clip new_h, new_w = get_resize_sizes(im_h, im_w, size) size = (new_w, new_h) else: size = size[0], size[1] - if interpolation == 'bilinear': + if interpolation == "bilinear": np_inter = cv2.INTER_LINEAR else: np_inter = cv2.INTER_NEAREST - scaled = [ - cv2.resize(img, size, interpolation=np_inter) for img in clip - ] + scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip] elif isinstance(clip[0], PIL.Image.Image): if isinstance(size, numbers.Number): im_w, im_h = clip[0].size # Min spatial dim already matches minimal size - if (im_w <= im_h and im_w == size) or (im_h <= im_w - and im_h == size): + if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): return clip new_h, new_w = get_resize_sizes(im_h, im_w, size) size = (new_w, new_h) else: size = size[1], size[0] - if interpolation == 'bilinear': + if interpolation == "bilinear": pil_inter = PIL.Image.BILINEAR else: pil_inter = PIL.Image.NEAREST scaled 
= [img.resize(size, pil_inter) for img in clip] else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + ) return scaled @@ -83,7 +81,7 @@ def get_resize_sizes(im_h, im_w, size): def normalize(clip, mean, std, inplace=False): if not _is_tensor_clip(clip): - raise TypeError('tensor is not a torch clip.') + raise TypeError("tensor is not a torch clip.") if not inplace: clip = clip.clone() diff --git a/src/datasets/utils/video/randaugment.py b/src/datasets/utils/video/randaugment.py index 4c80a990..3837b896 100644 --- a/src/datasets/utils/video/randaugment.py +++ b/src/datasets/utils/video/randaugment.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -50,46 +50,34 @@ def _check_args_tf(kwargs): def shear_x(img, factor, **kwargs): _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) def shear_y(img, factor, **kwargs): _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) def translate_x_rel(img, pct, **kwargs): pixels = pct * img.size[0] _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) def translate_y_rel(img, pct, **kwargs): pixels = pct * img.size[1] _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) def translate_x_abs(img, pixels, **kwargs): _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) def translate_y_abs(img, pixels, **kwargs): _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs - ) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) def rotate(img, degrees, **kwargs): @@ -334,12 +322,12 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None): self.magnitude = magnitude self.hparams = hparams.copy() self.kwargs = { - "fillcolor": hparams["img_mean"] - if "img_mean" in hparams - else _FILL, - "resample": hparams["interpolation"] - if "interpolation" in hparams - else _RANDOM_INTERPOLATION, + "fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL, + "resample": ( + hparams["interpolation"] + if "interpolation" in hparams + else _RANDOM_INTERPOLATION + ), } # If magnitude_std is > 0, we introduce some randomness @@ -356,15 +344,11 @@ def __call__(self, img_list): magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range level_args = ( - self.level_fn(magnitude, self.hparams) - if self.level_fn is not None - else () + self.level_fn(magnitude, self.hparams) if self.level_fn is not None else () ) if isinstance(img_list, list): - return [ - self.aug_fn(img, *level_args, **self.kwargs) for img in 
img_list - ] + return [self.aug_fn(img, *level_args, **self.kwargs) for img in img_list] else: return self.aug_fn(img_list, *level_args, **self.kwargs) @@ -512,7 +496,5 @@ def rand_augment_transform(config_str, hparams): ra_ops = rand_augment_ops( magnitude=magnitude, hparams=hparams, transforms=transforms ) - choice_weights = ( - None if weight_idx is None else _select_rand_weights(weight_idx) - ) + choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/src/datasets/utils/video/randerase.py b/src/datasets/utils/video/randerase.py index d1f185c8..b38602ea 100644 --- a/src/datasets/utils/video/randerase.py +++ b/src/datasets/utils/video/randerase.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -15,18 +15,14 @@ import torch -def _get_pixels( - per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" -): +def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"): # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() # paths, flip the order so normal is run on CPU if this becomes a problem # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 if per_pixel: return torch.empty(patch_size, dtype=dtype, device=device).normal_() elif rand_color: - return torch.empty( - (patch_size[0], 1, 1), dtype=dtype, device=device - ).normal_() + return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() else: return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) @@ -104,7 +100,7 @@ def _erase(self, img, chan, img_h, img_w, dtype): if w < img_w and h < img_h: top = random.randint(0, img_h - h) left = random.randint(0, img_w - w) - img[:, top:top + h, left:left + w] = _get_pixels( + img[:, top : top + h, left : left + w] = _get_pixels( self.per_pixel, self.rand_color, (chan, h, w), @@ -144,9 +140,7 @@ def _erase_cube( left = random.randint(0, img_w - w) for i in range(batch_start, batch_size): img_instance = img[i] - img_instance[ - :, top:top + h, left:left + w - ] = _get_pixels( + img_instance[:, top : top + h, left : left + w] = _get_pixels( self.per_pixel, self.rand_color, (chan, h, w), @@ -161,9 +155,7 @@ def __call__(self, input): else: batch_size, chan, img_h, img_w = input.size() # skip first slice of batch if num_splits is set (for clean portion of samples) - batch_start = ( - batch_size // self.num_splits if self.num_splits > 1 else 0 - ) + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 if self.cube: self._erase_cube( input, diff --git a/src/datasets/utils/video/transforms.py b/src/datasets/utils/video/transforms.py index ffa8e61d..2af9c69b 100644 --- a/src/datasets/utils/video/transforms.py +++ b/src/datasets/utils/video/transforms.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -22,12 +22,12 @@ _pil_interpolation_to_str = { - Image.NEAREST: 'PIL.Image.NEAREST', - Image.BILINEAR: 'PIL.Image.BILINEAR', - Image.BICUBIC: 'PIL.Image.BICUBIC', - Image.LANCZOS: 'PIL.Image.LANCZOS', - Image.HAMMING: 'PIL.Image.HAMMING', - Image.BOX: 'PIL.Image.BOX', + Image.NEAREST: "PIL.Image.NEAREST", + Image.BILINEAR: "PIL.Image.BILINEAR", + Image.BICUBIC: "PIL.Image.BICUBIC", + Image.LANCZOS: "PIL.Image.LANCZOS", + Image.HAMMING: "PIL.Image.HAMMING", + Image.BOX: "PIL.Image.BOX", } @@ -35,11 +35,11 @@ def _pil_interp(method): - if method == 'bicubic': + if method == "bicubic": return Image.BICUBIC - elif method == 'lanczos': + elif method == "lanczos": return Image.LANCZOS - elif method == 'hamming': + elif method == "hamming": return Image.HAMMING else: return Image.BILINEAR @@ -68,17 +68,13 @@ def random_short_side_scale_jitter( `num boxes` x 4. """ if inverse_uniform_sampling: - size = int( - round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) - ) + size = int(round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))) else: size = int(round(np.random.uniform(min_size, max_size))) height = images.shape[2] width = images.shape[3] - if (width <= height and width == size) or ( - height <= width and height == size - ): + if (width <= height and width == size) or (height <= width and height == size): return images, boxes new_width = size new_height = size @@ -95,7 +91,7 @@ def random_short_side_scale_jitter( torch.nn.functional.interpolate( images, size=(new_height, new_width), - mode='bilinear', + mode="bilinear", align_corners=False, ), boxes, @@ -146,13 +142,9 @@ def random_crop(images, size, boxes=None): x_offset = 0 if width > size: x_offset = int(np.random.randint(0, width - size)) - cropped = images[ - :, :, y_offset:y_offset + size, x_offset:x_offset + size - ] + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] - cropped_boxes = ( - crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None - ) + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None return cropped, cropped_boxes @@ -227,7 +219,7 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): images = torch.nn.functional.interpolate( images, size=(height, width), - mode='bilinear', + mode="bilinear", align_corners=False, ) @@ -244,12 +236,8 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): x_offset = 0 elif spatial_idx == 2: x_offset = width - size - cropped = images[ - :, :, y_offset:y_offset + size, x_offset:x_offset + size - ] - cropped_boxes = ( - crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None - ) + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] + cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None if ndim == 3: cropped = cropped.squeeze(0) return cropped, cropped_boxes @@ -306,9 +294,7 @@ def grayscale(images): """ # R -> 0.299, G -> 0.587, B -> 0.114. 
img_gray = torch.tensor(images) - gray_channel = ( - 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] - ) + gray_channel = 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] img_gray[:, 0] = gray_channel img_gray[:, 1] = gray_channel img_gray[:, 2] = gray_channel @@ -332,20 +318,20 @@ def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): jitter = [] if img_brightness != 0: - jitter.append('brightness') + jitter.append("brightness") if img_contrast != 0: - jitter.append('contrast') + jitter.append("contrast") if img_saturation != 0: - jitter.append('saturation') + jitter.append("saturation") if len(jitter) > 0: order = np.random.permutation(np.arange(len(jitter))) for idx in range(0, len(jitter)): - if jitter[order[idx]] == 'brightness': + if jitter[order[idx]] == "brightness": images = brightness_jitter(img_brightness, images) - elif jitter[order[idx]] == 'contrast': + elif jitter[order[idx]] == "contrast": images = contrast_jitter(img_contrast, images) - elif jitter[order[idx]] == 'saturation': + elif jitter[order[idx]] == "saturation": images = saturation_jitter(img_saturation, images) return images @@ -439,7 +425,7 @@ def lighting_jitter(images, alphastd, eigval, eigvec): # T C H W channel_dim = 1 else: - raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") for idx in range(images.shape[channel_dim]): # C H W @@ -449,9 +435,7 @@ def lighting_jitter(images, alphastd, eigval, eigvec): elif len(images.shape) == 4: out_images[:, idx] = images[:, idx] + rgb[2 - idx] else: - raise NotImplementedError( - f'Unsupported dimension {len(images.shape)}' - ) + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") return out_images @@ -470,21 +454,13 @@ def color_normalization(images, mean, stddev): `num frames` x `channel` x `height` x `width`. 
""" if len(images.shape) == 3: - assert ( - len(mean) == images.shape[0] - ), 'channel mean not computed properly' - assert ( - len(stddev) == images.shape[0] - ), 'channel stddev not computed properly' + assert len(mean) == images.shape[0], "channel mean not computed properly" + assert len(stddev) == images.shape[0], "channel stddev not computed properly" elif len(images.shape) == 4: - assert ( - len(mean) == images.shape[1] - ), 'channel mean not computed properly' - assert ( - len(stddev) == images.shape[1] - ), 'channel stddev not computed properly' + assert len(mean) == images.shape[1], "channel mean not computed properly" + assert len(stddev) == images.shape[1], "channel stddev not computed properly" else: - raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") out_images = torch.zeros_like(images) for idx in range(len(mean)): @@ -494,9 +470,7 @@ def color_normalization(images, mean, stddev): elif len(images.shape) == 4: out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] else: - raise NotImplementedError( - f'Unsupported dimension {len(images.shape)}' - ) + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") return out_images @@ -568,11 +542,11 @@ def random_resized_crop( width = images.shape[3] i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) - cropped = images[:, :, i:i + h, j:j + w] + cropped = images[:, :, i : i + h, j : j + w] return torch.nn.functional.interpolate( cropped, size=(target_height, target_width), - mode='bilinear', + mode="bilinear", align_corners=False, ) @@ -608,15 +582,15 @@ def random_resized_crop_with_shift( w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] out = torch.zeros((3, t, target_height, target_width)) for ind in range(t): - out[:, ind:ind + 1, :, :] = torch.nn.functional.interpolate( + out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate( images[ :, - ind:ind + 1, - i_s[ind]:i_s[ind] + h_s[ind], - j_s[ind]:j_s[ind] + w_s[ind], + ind : ind + 1, + i_s[ind] : i_s[ind] + h_s[ind], + j_s[ind] : j_s[ind] + w_s[ind], ], size=(target_height, target_width), - mode='bilinear', + mode="bilinear", align_corners=False, ) return out @@ -625,7 +599,7 @@ def random_resized_crop_with_shift( def create_random_augment( input_size, auto_augment=None, - interpolation='bilinear', + interpolation="bilinear", ): """ Get video randaug transform. @@ -648,13 +622,11 @@ def create_random_augment( img_size_min = min(img_size) else: img_size_min = img_size - aa_params = {'translate_const': int(img_size_min * 0.45)} - if interpolation and interpolation != 'random': - aa_params['interpolation'] = _pil_interp(interpolation) - if auto_augment.startswith('rand'): - return transforms.Compose( - [rand_augment_transform(auto_augment, aa_params)] - ) + aa_params = {"translate_const": int(img_size_min * 0.45)} + if interpolation and interpolation != "random": + aa_params["interpolation"] = _pil_interp(interpolation) + if auto_augment.startswith("rand"): + return transforms.Compose([rand_augment_transform(auto_augment, aa_params)]) raise NotImplementedError @@ -668,9 +640,7 @@ def random_sized_crop_img( """ Performs Inception-style cropping (used for training). 
""" - assert ( - len(im.shape) == 3 - ), 'Currently only support image for random_sized_crop' + assert len(im.shape) == 3, "Currently only support image for random_sized_crop" h, w = im.shape[1:3] i, j, h, w = _get_param_spatial_crop( scale=jitter_scale, @@ -681,11 +651,11 @@ def random_sized_crop_img( log_scale=False, switch_hw=True, ) - cropped = im[:, i:i + h, j:j + w] + cropped = im[:, i : i + h, j : j + w] return torch.nn.functional.interpolate( cropped.unsqueeze(0), size=(size, size), - mode='bilinear', + mode="bilinear", align_corners=False, ).squeeze(0) @@ -711,16 +681,16 @@ def __init__( size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), - interpolation='bilinear', + interpolation="bilinear", ): if isinstance(size, tuple): self.size = size else: self.size = (size, size) if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): - print('range should be of kind (min, max)') + print("range should be of kind (min, max)") - if interpolation == 'random': + if interpolation == "random": self.interpolation = _RANDOM_INTERPOLATION else: self.interpolation = _pil_interp(interpolation) @@ -784,19 +754,15 @@ def __call__(self, img): def __repr__(self): if isinstance(self.interpolation, (tuple, list)): - interpolate_str = ' '.join( + interpolate_str = " ".join( [_pil_interpolation_to_str[x] for x in self.interpolation] ) else: interpolate_str = _pil_interpolation_to_str[self.interpolation] - format_string = self.__class__.__name__ + '(size={0}'.format(self.size) - format_string += ', scale={0}'.format( - tuple(round(s, 4) for s in self.scale) - ) - format_string += ', ratio={0}'.format( - tuple(round(r, 4) for r in self.ratio) - ) - format_string += ', interpolation={0})'.format(interpolate_str) + format_string = self.__class__.__name__ + "(size={0}".format(self.size) + format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) + format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) + format_string += ", interpolation={0})".format(interpolate_str) return format_string @@ -833,12 +799,12 @@ def __call__(self, clip): if isinstance(clip[0], np.ndarray): return [np.fliplr(img) for img in clip] elif isinstance(clip[0], PIL.Image.Image): - return [ - img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip - ] + return [img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip] else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - ' but got list of {0}'.format(type(clip[0]))) + raise TypeError( + "Expected numpy.ndarray or PIL.Image" + + " but got list of {0}".format(type(clip[0])) + ) return clip @@ -852,7 +818,7 @@ class RandomResize(object): size (tuple): (widht, height) """ - def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'):
+    def __init__(self, ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation="nearest"):
         self.ratio = ratio
         self.interpolation = interpolation
@@ -867,8 +833,7 @@ def __call__(self, clip):
         new_w = int(im_w * scaling_factor)
         new_h = int(im_h * scaling_factor)
         new_size = (new_w, new_h)
-        resized = FF.resize_clip(
-            clip, new_size, interpolation=self.interpolation)
+        resized = FF.resize_clip(clip, new_size, interpolation=self.interpolation)
         return resized
@@ -882,13 +847,12 @@ class Resize(object):
         size (tuple): (widht, height)
     """
-    def __init__(self, size, interpolation='nearest'):
+    def __init__(self, size, interpolation="nearest"):
         self.size = size
         self.interpolation = interpolation
     def __call__(self, clip):
-        resized = FF.resize_clip(
-            clip, self.size, interpolation=self.interpolation)
+        resized = FF.resize_clip(clip, self.size, interpolation=self.interpolation)
         return resized
@@ -919,14 +883,18 @@ def __call__(self, clip):
         elif isinstance(clip[0], PIL.Image.Image):
             im_w, im_h = clip[0].size
         else:
-            raise TypeError('Expected numpy.ndarray or PIL.Image' +
-                            'but got list of {0}'.format(type(clip[0])))
+            raise TypeError(
+                "Expected numpy.ndarray or PIL.Image"
+                + " but got list of {0}".format(type(clip[0]))
+            )
         if w > im_w or h > im_h:
             error_msg = (
-                'Initial image size should be larger then '
-                'cropped size but got cropped sizes : ({w}, {h}) while '
-                'initial image is ({im_w}, {im_h})'.format(
-                    im_w=im_w, im_h=im_h, w=w, h=h))
+                "Initial image size should be larger than "
+                "cropped size but got cropped sizes : ({w}, {h}) while "
+                "initial image is ({im_w}, {im_h})".format(
+                    im_w=im_w, im_h=im_h, w=w, h=h
+                )
+            )
             raise ValueError(error_msg)
         x1 = random.randint(0, im_w - w)
@@ -963,8 +931,10 @@ def __call__(self, clip):
         elif isinstance(clip[0], PIL.Image.Image):
             im_w, im_h = clip[0].size
         else:
-            raise TypeError('Expected numpy.ndarray or PIL.Image' +
-                            'but got list of {0}'.format(type(clip[0])))
+            raise TypeError(
+                "Expected numpy.ndarray or PIL.Image"
+                + " but got list of {0}".format(type(clip[0]))
+            )
         if w != im_w and h != im_h:
             clip = FF.resize_clip(clip, self.size, interpolation="bilinear")
             im_h, im_w, im_c = clip[0].shape
@@ -972,7 +942,7 @@ def __call__(self, clip):
         step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0)
         cropped = []
         for i in range(3):
-            if (im_h > self.size[0]):
+            if im_h > self.size[0]:
                 x1 = 0
                 y1 = i * step
                 cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
@@ -995,13 +965,11 @@ class RandomRotation(object):
     def __init__(self, degrees):
         if isinstance(degrees, numbers.Number):
             if degrees < 0:
-                raise ValueError('If degrees is a single number,'
-                                 'must be positive')
+                raise ValueError("If degrees is a single number, " "must be positive")
             degrees = (-degrees, degrees)
         else:
             if len(degrees) != 2:
-                raise ValueError('If degrees is a sequence,'
-                                 'it must be of len 2.')
+                raise ValueError("If degrees is a sequence, " "it must be of len 2.")
         self.degrees = degrees
@@ -1014,14 +982,17 @@ def __call__(self, clip):
         PIL.Image or numpy.ndarray: Cropped list of images
         """
         import skimage
+
         angle = random.uniform(self.degrees[0], self.degrees[1])
         if isinstance(clip[0], np.ndarray):
             rotated = [skimage.transform.rotate(img, angle) for img in clip]
         elif isinstance(clip[0], PIL.Image.Image):
             rotated = [img.rotate(angle) for img in clip]
         else:
-            raise TypeError('Expected numpy.ndarray or PIL.Image' +
-                            'but got list of {0}'.format(type(clip[0])))
+            raise TypeError(
+                "Expected numpy.ndarray or PIL.Image"
+                + " but got list of {0}".format(type(clip[0]))
+            )
         return rotated
@@ -1053,18 +1024,22 @@ def __call__(self, clip):
         elif isinstance(clip[0], PIL.Image.Image):
             im_w, im_h = clip[0].size
         else:
-            raise TypeError('Expected numpy.ndarray or PIL.Image' +
-                            'but got list of {0}'.format(type(clip[0])))
+            raise TypeError(
+                "Expected numpy.ndarray or PIL.Image"
+                + " but got list of {0}".format(type(clip[0]))
+            )
         if w > im_w or h > im_h:
             error_msg = (
-                'Initial image size should be larger then '
-                'cropped size but got cropped sizes : ({w}, {h}) while '
-                'initial image is ({im_w}, {im_h})'.format(
-                    im_w=im_w, im_h=im_h, w=w, h=h))
+                "Initial image size should be larger than "
+                "cropped size but got cropped sizes : ({w}, {h}) while "
+                "initial image is ({im_w}, {im_h})".format(
+                    im_w=im_w, im_h=im_h, w=w, h=h
+                )
+            )
             raise ValueError(error_msg)
-        x1 = int(round((im_w - w) / 2.))
-        y1 = int(round((im_h - h) / 2.))
+        x1 = int(round((im_w - w) / 2.0))
+        y1 = int(round((im_h - h) / 2.0))
         cropped = FF.crop_clip(clip, y1, x1, h, w)
         return cropped
@@ -1093,20 +1068,17 @@ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
     def get_params(self, brightness, contrast, saturation, hue):
         if brightness > 0:
-            brightness_factor = random.uniform(
-                max(0, 1 - brightness), 1 + brightness)
+            brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
         else:
             brightness_factor = None
         if contrast > 0:
-            contrast_factor = random.uniform(
-                max(0, 1 - contrast), 1 + contrast)
+            contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
         else:
             contrast_factor = None
         if saturation > 0:
-            saturation_factor = random.uniform(
-                max(0, 1 - saturation), 1 + saturation)
+            saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
         else:
             saturation_factor = None
@@ -1124,22 +1096,36 @@ def __call__(self, clip):
         list PIL.Image : list of transformed PIL.Image
         """
         if isinstance(clip[0], np.ndarray):
-            raise TypeError(
-                'Color jitter not yet implemented for numpy arrays')
+            raise TypeError("Color jitter not yet implemented for numpy arrays")
         elif isinstance(clip[0], PIL.Image.Image):
             brightness, contrast, saturation, hue = self.get_params(
-                self.brightness, self.contrast, self.saturation, self.hue)
+                self.brightness, self.contrast, self.saturation, self.hue
+            )
             # Create img transform function sequence
             img_transforms = []
             if brightness is not None:
-                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+                img_transforms.append(
+                    lambda img: torchvision.transforms.functional.adjust_brightness(
+                        img, brightness
+                    )
+                )
             if saturation is not None:
-                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+                img_transforms.append(
+                    lambda img: torchvision.transforms.functional.adjust_saturation(
+                        img, saturation
+                    )
+                )
             if hue is not None:
-                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+                img_transforms.append(
+                    lambda img: torchvision.transforms.functional.adjust_hue(img, hue)
+                )
             if contrast is not None:
-                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+                img_transforms.append(
+                    lambda img: torchvision.transforms.functional.adjust_contrast(
+                        img, contrast
+                    )
+                )
             random.shuffle(img_transforms)
             # Apply to all images
@@ -1150,8 +1136,10 @@ def __call__(self, clip):
                 jittered_clip.append(jittered_img)
         else:
-            raise TypeError('Expected numpy.ndarray or PIL.Image' +
-                            'but got list of {0}'.format(type(clip[0])))
+            raise TypeError(
+
"Expected numpy.ndarray or PIL.Image" + + "but got list of {0}".format(type(clip[0])) + ) return jittered_clip @@ -1181,4 +1169,6 @@ def __call__(self, clip): return FF.normalize(clip, self.mean, self.std) def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + return self.__class__.__name__ + "(mean={0}, std={1})".format( + self.mean, self.std + ) diff --git a/src/datasets/utils/video/volume_transforms.py b/src/datasets/utils/video/volume_transforms.py index 0a01bb36..cf42d06e 100644 --- a/src/datasets/utils/video/volume_transforms.py +++ b/src/datasets/utils/video/volume_transforms.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/datasets/utils/weighted_sampler.py b/src/datasets/utils/weighted_sampler.py index fd40825e..f411c8f3 100644 --- a/src/datasets/utils/weighted_sampler.py +++ b/src/datasets/utils/weighted_sampler.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -10,12 +10,7 @@ import numpy as np import torch -from torch.utils.data import ( - Dataset, - Sampler, - DistributedSampler, - WeightedRandomSampler -) +from torch.utils.data import Dataset, Sampler, DistributedSampler, WeightedRandomSampler class DatasetFromSampler(Dataset): @@ -34,7 +29,7 @@ def __len__(self) -> int: class DistributedSamplerWrapper(DistributedSampler): - """ Convert any Pytorch Sampler to a DistributedSampler """ + """Convert any Pytorch Sampler to a DistributedSampler""" def __init__( self, @@ -59,7 +54,7 @@ def __iter__(self) -> Iterator[int]: class CustomWeightedRandomSampler(WeightedRandomSampler): - """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ + """Generalized WeightedRandomSampler to allow for more than 2^24 samples""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -69,7 +64,7 @@ def __iter__(self): range(0, len(self.weights)), size=self.num_samples, p=self.weights.numpy() / torch.sum(self.weights).numpy(), - replace=self.replacement + replace=self.replacement, ) rand_tensor = torch.from_numpy(rand_tensor) return iter(rand_tensor.tolist()) @@ -85,9 +80,8 @@ def __init__( shuffle: bool = True, ): weighted_sampler = CustomWeightedRandomSampler( - weights=weights, - num_samples=len(weights), - replacement=False) + weights=weights, num_samples=len(weights), replacement=False + ) super(DistributedWeightedSampler, self).__init__( sampler=weighted_sampler, diff --git a/src/datasets/video_dataset.py b/src/datasets/video_dataset.py index b05cc701..afdbe4c8 100644 --- a/src/datasets/video_dataset.py +++ b/src/datasets/video_dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -58,21 +58,18 @@ def make_videodataset( filter_long_videos=filter_long_videos, duration=duration, shared_transform=shared_transform, - transform=transform) + transform=transform, + ) - logger.info('VideoDataset dataset created') + logger.info("VideoDataset dataset created") if datasets_weights is not None: dist_sampler = DistributedWeightedSampler( - dataset.sample_weights, - num_replicas=world_size, - rank=rank, - shuffle=True) + dataset.sample_weights, num_replicas=world_size, rank=rank, shuffle=True + ) else: dist_sampler = torch.utils.data.distributed.DistributedSampler( - dataset, - num_replicas=world_size, - rank=rank, - shuffle=True) + dataset, num_replicas=world_size, rank=rank, shuffle=True + ) data_loader = torch.utils.data.DataLoader( dataset, @@ -82,14 +79,15 @@ def make_videodataset( drop_last=drop_last, pin_memory=pin_mem, num_workers=num_workers, - persistent_workers=num_workers > 0) - logger.info('VideoDataset unsupervised data loader created') + persistent_workers=num_workers > 0, + ) + logger.info("VideoDataset unsupervised data loader created") return dataset, data_loader, dist_sampler class VideoDataset(torch.utils.data.Dataset): - """ Video classification dataset. """ + """Video classification dataset.""" def __init__( self, @@ -119,22 +117,26 @@ def __init__( self.filter_long_videos = filter_long_videos self.duration = duration + self.frame_sample_rate = None + if VideoReader is None: - raise ImportError('Unable to import "decord" which is required to read videos.') + raise ImportError( + 'Unable to import "decord" which is required to read videos.' + ) # Load video paths and labels samples, labels = [], [] self.num_samples_per_dataset = [] for data_path in self.data_paths: - if data_path[-4:] == '.csv': + if data_path[-4:] == ".csv": data = pd.read_csv(data_path, header=None, delimiter=" ") samples += list(data.values[:, 0]) labels += list(data.values[:, 1]) num_samples = len(data) self.num_samples_per_dataset.append(num_samples) - elif data_path[-4:] == '.npy': + elif data_path[-4:] == ".npy": data = np.load(data_path, allow_pickle=True) data = list(map(lambda x: repr(x)[1:-1], data)) samples += data @@ -169,10 +171,10 @@ def __getitem__(self, index): label = self.labels[index] def split_into_clips(video): - """ Split video into a list of clips """ + """Split video into a list of clips""" fpc = self.frames_per_clip nc = self.num_clips - return [video[i*fpc:(i+1)*fpc] for i in range(nc)] + return [video[i * fpc : (i + 1) * fpc] for i in range(nc)] # Parse video into frames & apply data augmentations if self.shared_transform is not None: @@ -181,22 +183,42 @@ def split_into_clips(video): if self.transform is not None: buffer = [self.transform(clip) for clip in buffer] - return buffer, label, clip_indices + # Load Action Data + action_filepath = os.path.join(os.path.dirname(sample), "drive_data.csv") + action_df = pd.read_csv(action_filepath) + action_labels = self.get_action_labels_for_clip(action_df, clip_indices) + + return buffer, label, clip_indices, action_labels + + def get_action_labels_for_clip(self, action_df, clip_indices): + # Convert video frame indices to timestamps in seconds + frame_timestamps = clip_indices / self.frame_sample_rate + + # Find the corresponding actions for each frame + action_labels = [] + for timestamp in frame_timestamps: + # Find the row with the nearest timestamp (modify this for interpolation if needed) + nearest_row = action_df.iloc[ + (action_df["timestamp"] - 
timestamp).abs().argmin() + ] + action_labels.append(nearest_row["action_name"]) + + return action_labels def loadvideo_decord(self, sample): - """ Load video content using Decord """ + """Load video content using Decord""" fname = sample if not os.path.exists(fname): - warnings.warn(f'video path not found {fname=}') + warnings.warn(f"video path not found {fname=}") return [], None _fsize = os.path.getsize(fname) if _fsize < 1 * 1024: # avoid hanging issue - warnings.warn(f'video too short {fname=}') + warnings.warn(f"video too short {fname=}") return [], None if _fsize > self.filter_long_videos: - warnings.warn(f'skipping long video of size {_fsize=} (bytes)') + warnings.warn(f"skipping long video of size {_fsize=} (bytes)") return [], None try: @@ -215,7 +237,7 @@ def loadvideo_decord(self, sample): clip_len = int(fpc * fstp) if self.filter_short_videos and len(vr) < clip_len: - warnings.warn(f'skipping video of length {len(vr)}') + warnings.warn(f"skipping video of length {len(vr)}") return [], None vr.seek(0) # Go to start of video before sampling frames @@ -235,7 +257,7 @@ def loadvideo_decord(self, sample): end_indx = np.random.randint(clip_len, partition_len) start_indx = end_indx - clip_len indices = np.linspace(start_indx, end_indx, num=fpc) - indices = np.clip(indices, start_indx, end_indx-1).astype(np.int64) + indices = np.clip(indices, start_indx, end_indx - 1).astype(np.int64) # -- indices = indices + i * partition_len else: @@ -244,8 +266,13 @@ def loadvideo_decord(self, sample): # we reach the desired clip length if not self.allow_clip_overlap: indices = np.linspace(0, partition_len, num=partition_len // fstp) - indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,)) - indices = np.clip(indices, 0, partition_len-1).astype(np.int64) + indices = np.concatenate( + ( + indices, + np.ones(fpc - partition_len // fstp) * partition_len, + ) + ) + indices = np.clip(indices, 0, partition_len - 1).astype(np.int64) # -- indices = indices + i * partition_len @@ -254,8 +281,13 @@ def loadvideo_decord(self, sample): else: sample_len = min(clip_len, len(vr)) - 1 indices = np.linspace(0, sample_len, num=sample_len // fstp) - indices = np.concatenate((indices, np.ones(fpc - sample_len // fstp) * sample_len,)) - indices = np.clip(indices, 0, sample_len-1).astype(np.int64) + indices = np.concatenate( + ( + indices, + np.ones(fpc - sample_len // fstp) * sample_len, + ) + ) + indices = np.clip(indices, 0, sample_len - 1).astype(np.int64) # -- clip_step = 0 if len(vr) > clip_len: @@ -266,6 +298,10 @@ def loadvideo_decord(self, sample): all_indices.extend(list(indices)) buffer = vr.get_batch(all_indices).asnumpy() + + # Added the following line to extract the frame rate from video metadata + self.frame_sample_rate = vr.get_avg_fps() + return buffer, clip_indices def __len__(self): diff --git a/src/masks/__init__.py b/src/masks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/masks/default.py b/src/masks/default.py index 2810c0a1..a95bbe02 100644 --- a/src/masks/default.py +++ b/src/masks/default.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/masks/multiblock3d.py b/src/masks/multiblock3d.py index a7bbc3e1..8b9f0483 100644 --- a/src/masks/multiblock3d.py +++ b/src/masks/multiblock3d.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -36,12 +36,12 @@ def __init__( num_frames=num_frames, spatial_patch_size=patch_size, temporal_patch_size=tubelet_size, - spatial_pred_mask_scale=m.get('spatial_scale'), - temporal_pred_mask_scale=m.get('temporal_scale'), - aspect_ratio=m.get('aspect_ratio'), - npred=m.get('num_blocks'), - max_context_frames_ratio=m.get('max_temporal_keep', 1.0), - max_keep=m.get('max_keep', None), + spatial_pred_mask_scale=m.get("spatial_scale"), + temporal_pred_mask_scale=m.get("temporal_scale"), + aspect_ratio=m.get("aspect_ratio"), + npred=m.get("num_blocks"), + max_context_frames_ratio=m.get("max_temporal_keep", 1.0), + max_keep=m.get("max_keep", None), ) self.mask_generators.append(mask_generator) @@ -62,7 +62,21 @@ def __call__(self, batch): return collated_batch, collated_masks_enc, collated_masks_pred +class MaskCollatorWithActions(MaskCollator): + def __call__(self, batch): + images, maneuver_ids = zip(*batch) + # collated_images = torch.utils.data.default_collate(images) + collated_maneuvers = torch.tensor(maneuver_ids) + collated_images = list(images) # Keep images as a list of PIL images + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(len(collated_images)) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + return collated_images, collated_maneuvers, collated_masks_enc, collated_masks_pred + class _MaskGenerator(object): def __init__( @@ -80,9 +94,12 @@ def __init__( ): super(_MaskGenerator, self).__init__() if not isinstance(crop_size, tuple): - crop_size = (crop_size, ) * 2 + crop_size = (crop_size,) * 2 self.crop_size = crop_size - self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.height, self.width = ( + crop_size[0] // spatial_patch_size, + crop_size[1] // spatial_patch_size, + ) self.duration = num_frames // temporal_patch_size self.spatial_patch_size = spatial_patch_size @@ -92,9 +109,11 @@ def __init__( self.spatial_pred_mask_scale = spatial_pred_mask_scale self.temporal_pred_mask_scale = temporal_pred_mask_scale self.npred = npred - self.max_context_duration = max(1, int(self.duration * max_context_frames_ratio)) # maximum number of time-steps (frames) spanned by context mask + self.max_context_duration = max( + 1, int(self.duration * max_context_frames_ratio) + ) # maximum number of time-steps (frames) spanned by context mask self.max_keep = max_keep # maximum number of patches to keep in context - self._itr_counter = Value('i', -1) # collator is shared across worker processes + self._itr_counter = Value("i", -1) # collator is shared across worker processes def step(self): i = self._itr_counter @@ -104,11 +123,7 @@ def step(self): return v def _sample_block_size( - self, - generator, - temporal_scale, - spatial_scale, - aspect_ratio_scale + self, generator, temporal_scale, spatial_scale, aspect_ratio_scale ): # -- Sample temporal block mask scale _rand = torch.rand(1, generator=generator).item() @@ -142,12 +157,12 @@ def _sample_block_mask(self, b_size): start = torch.randint(0, self.duration - t + 1, (1,)) mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) - mask[start:start+t, top:top+h, left:left+w] = 0 + mask[start : start + t, top : top + h, left : left + w] = 0 # Context mask will only span the first X frames # 
(X=self.max_context_frames) if self.max_context_duration < self.duration: - mask[self.max_context_duration:, :, :] = 0 + mask[self.max_context_duration :, :, :] = 0 # -- return mask @@ -176,7 +191,9 @@ def __call__(self, batch_size): empty_context = True while empty_context: - mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + mask_e = torch.ones( + (self.duration, self.height, self.width), dtype=torch.int32 + ) for _ in range(self.npred): mask_e *= self._sample_block_mask(p_size) mask_e = mask_e.flatten() diff --git a/src/masks/random_tube.py b/src/masks/random_tube.py index 84c06402..9d8a82ec 100644 --- a/src/masks/random_tube.py +++ b/src/masks/random_tube.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -35,7 +35,7 @@ def __init__( num_frames=num_frames, spatial_patch_size=patch_size, temporal_patch_size=tubelet_size, - ratio=m.get('ratio'), + ratio=m.get("ratio", 0.9), ) self.mask_generators.append(mask_generator) @@ -56,6 +56,20 @@ def __call__(self, batch): return collated_batch, collated_masks_enc, collated_masks_pred +class MaskCollatorWithActions(MaskCollator): + def __call__(self, batch): + images, maneuver_ids = zip(*batch) + # collated_images = torch.utils.data.default_collate(images) + collated_maneuvers = torch.tensor(maneuver_ids) + collated_images = list(images) # Keep images as a list of PIL images + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(len(collated_images)) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_images, collated_maneuvers, collated_masks_enc, collated_masks_pred class _MaskGenerator(object): @@ -69,21 +83,24 @@ def __init__( ): super(_MaskGenerator, self).__init__() if not isinstance(crop_size, tuple): - crop_size = (crop_size, ) * 2 + crop_size = (crop_size,) * 2 self.crop_size = crop_size - self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.height, self.width = ( + crop_size[0] // spatial_patch_size, + crop_size[1] // spatial_patch_size, + ) self.duration = num_frames // temporal_patch_size self.spatial_patch_size = spatial_patch_size self.temporal_patch_size = temporal_patch_size - self.num_patches_spatial = self.height*self.width + self.num_patches_spatial = self.height * self.width self.ratio = ratio - self.num_keep_spatial = int(self.num_patches_spatial*(1.-self.ratio)) + self.num_keep_spatial = int(self.num_patches_spatial * (1.0 - self.ratio)) self.num_keep = self.num_keep_spatial * self.duration - self._itr_counter = Value('i', -1) # collator is shared across worker processes + self._itr_counter = Value("i", -1) # collator is shared across worker processes def step(self): i = self._itr_counter @@ -94,10 +111,12 @@ def step(self): def __call__(self, batch_size): def sample_mask(): - mask = np.hstack([ - np.zeros(self.num_patches_spatial - self.num_keep_spatial), - np.ones(self.num_keep_spatial), - ]) + mask = np.hstack( + [ + np.zeros(self.num_patches_spatial - self.num_keep_spatial), + np.ones(self.num_keep_spatial), + ] + ) np.random.shuffle(mask) mask = torch.tensor(np.tile(mask, (self.duration, 1))) mask = mask.flatten() diff --git a/src/masks/utils.py b/src/masks/utils.py index ca04af1f..bb0ad76a 100644 --- a/src/masks/utils.py +++ 
b/src/masks/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/models/action_encoders.py b/src/models/action_encoders.py new file mode 100644 index 00000000..d6c37323 --- /dev/null +++ b/src/models/action_encoders.py @@ -0,0 +1,35 @@ +import math +from functools import partial + +import torch +import torch.nn as nn + + +class ActionEncoderDiscrete(nn.Module): + def __init__(self, num_actions, embed_dim, hidden_dim): + super(ActionEncoderDiscrete, self).__init__() + self.embedding = nn.Embedding(num_actions, embed_dim) + self.linear = nn.Linear(embed_dim, hidden_dim) + + def forward(self, actions): + embedded_actions = self.embedding(actions) + encoded_actions = self.linear(embedded_actions) + return encoded_actions + + +class ActionEncoderContinuous(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers): + super(ActionEncoderContinuous, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + *[ + nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) + for _ in range(num_layers - 1) + ], + nn.Linear(hidden_dim, hidden_dim) + ) + + def forward(self, actions): + encoded_actions = self.mlp(actions) + return encoded_actions diff --git a/src/models/attentive_pooler.py b/src/models/attentive_pooler.py index ecd9986a..122a92b4 100644 --- a/src/models/attentive_pooler.py +++ b/src/models/attentive_pooler.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -10,16 +10,13 @@ import torch import torch.nn as nn -from src.models.utils.modules import ( - Block, - CrossAttention, - CrossAttentionBlock -) +from src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock from src.utils.tensors import trunc_normal_ class AttentivePooler(nn.Module): - """ Attentive Pooler """ + """Attentive Pooler""" + def __init__( self, num_queries=1, @@ -30,7 +27,7 @@ def __init__( norm_layer=nn.LayerNorm, init_std=0.02, qkv_bias=True, - complete_block=True + complete_block=True, ): super().__init__() self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) @@ -42,24 +39,28 @@ def __init__( num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - norm_layer=norm_layer) + norm_layer=norm_layer, + ) else: self.cross_attention_block = CrossAttention( - dim=embed_dim, - num_heads=num_heads, - qkv_bias=qkv_bias) + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias + ) self.blocks = None if depth > 1: - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=False, - norm_layer=norm_layer) - for i in range(depth-1)]) + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=False, + norm_layer=norm_layer, + ) + for i in range(depth - 1) + ] + ) self.init_std = init_std trunc_normal_(self.query_tokens, std=self.init_std) @@ -103,7 +104,8 @@ def forward(self, x): class AttentiveClassifier(nn.Module): - """ Attentive Classifier """ + """Attentive Classifier""" + def __init__( self, embed_dim=768, diff --git a/src/models/predictor.py b/src/models/predictor.py index 2dd9a38b..0f76f457 100644 --- a/src/models/predictor.py +++ b/src/models/predictor.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -13,15 +13,13 @@ from src.models.utils.modules import Block from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed -from src.utils.tensors import ( - trunc_normal_, - repeat_interleave_batch -) +from src.utils.tensors import trunc_normal_, repeat_interleave_batch from src.masks.utils import apply_masks class VisionTransformerPredictor(nn.Module): - """ Vision Transformer """ + """Vision Transformer""" + def __init__( self, img_size=224, @@ -54,10 +52,12 @@ def __init__( self.num_mask_tokens = 0 if use_mask_tokens: self.num_mask_tokens = num_mask_tokens - self.mask_tokens = nn.ParameterList([ - nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) - for i in range(num_mask_tokens) - ]) + self.mask_tokens = nn.ParameterList( + [ + nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + for i in range(num_mask_tokens) + ] + ) # Determine positional embedding self.input_size = img_size @@ -77,32 +77,35 @@ def __init__( * (img_size // patch_size) ) else: - self.num_patches = num_patches = ( - (img_size // patch_size) - * (img_size // patch_size) + self.num_patches = num_patches = (img_size // patch_size) * ( + img_size // patch_size ) # Position embedding self.uniform_power = uniform_power self.predictor_pos_embed = None self.predictor_pos_embed = nn.Parameter( - torch.zeros(1, num_patches, predictor_embed_dim), - requires_grad=False) + torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False + ) # Attention Blocks - self.predictor_blocks = nn.ModuleList([ - Block( - dim=predictor_embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=nn.GELU, - attn_drop=attn_drop_rate, - grid_size=grid_size, - grid_depth=grid_depth, - norm_layer=norm_layer) - for i in range(depth)]) + self.predictor_blocks = nn.ModuleList( + [ + Block( + dim=predictor_embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + attn_drop=attn_drop_rate, + grid_size=grid_size, + grid_depth=grid_depth, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) # Normalize & project back to input dimension self.predictor_norm = norm_layer(predictor_embed_dim) @@ -128,7 +131,7 @@ def _init_pos_embed(self, pos_embed): grid_size, grid_depth, cls_token=False, - uniform_power=self.uniform_power + uniform_power=self.uniform_power, ) else: sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) @@ -155,20 +158,26 @@ def diffusion(self, x, noise_beta=(0.5, 1.0), steps=1000): # Prepare diffusion noise schedule b1, b2 = noise_beta - beta_scheduler = (b1 + i*(b2-b1)/steps for i in range(steps)) + beta_scheduler = (b1 + i * (b2 - b1) / steps for i in range(steps)) alpha_scheduler = [] _alpha = 1.0 for _beta in beta_scheduler: - _alpha *= 1.-_beta + _alpha *= 1.0 - _beta alpha_scheduler += [_alpha] # Sample diffusion time step T = torch.randint(0, steps, (len(x),)) - alpha = torch.tensor(alpha_scheduler, device=x.device)[T].unsqueeze(-1).unsqueeze(-1) + alpha = ( + torch.tensor(alpha_scheduler, device=x.device)[T] + .unsqueeze(-1) + .unsqueeze(-1) + ) # Normalize features and apply noise x = torch.nn.functional.layer_norm(x, (x.size(-1),)) - x = alpha**0.5 * x + (1.-alpha)**0.5 * torch.randn(x.shape, device=x.device) + x = alpha**0.5 * x + (1.0 - alpha) ** 0.5 * torch.randn( + x.shape, device=x.device + ) return x def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, 
mask_index=1): @@ -179,7 +188,9 @@ def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): :params masks_tgt: indices of target tokens in input """ - assert (masks_ctxt is not None) and (masks_tgt is not None), 'Cannot run predictor without mask indices' + assert (masks_ctxt is not None) and ( + masks_tgt is not None + ), "Cannot run predictor without mask indices" if not isinstance(masks_ctxt, list): masks_ctxt = [masks_ctxt] @@ -241,6 +252,6 @@ def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): def vit_predictor(**kwargs): model = VisionTransformerPredictor( - mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs) + mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + ) return model diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/models/utils/combine_encodings.py b/src/models/utils/combine_encodings.py new file mode 100644 index 00000000..c8f0c223 --- /dev/null +++ b/src/models/utils/combine_encodings.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def combine_encodings_concat(z, z_a): + """ + Concatenation: Concatenate the encoded video clips and actions along the feature dimension. + """ + z_combined = torch.cat([z, z_a], dim=-1) + return z_combined + + +def combine_encodings_add(z, z_a): + """ + Addition: Add the encoded video clips and actions element-wise. + """ + z_combined = z + z_a + return z_combined + + +class AttentionFusion(nn.Module): + """ + Attention-based fusion: Use an attention mechanism to weight the importance of video clips and actions based on their relevance. + """ + + def __init__(self, hidden_dim): + super(AttentionFusion, self).__init__() + self.attention = nn.MultiheadAttention(hidden_dim, num_heads=8) + + def forward(self, z, z_a): + z_combined, _ = self.attention(z, z_a, z_a) + return z_combined diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py index dc470d9b..f95ea0a1 100644 --- a/src/models/utils/modules.py +++ b/src/models/utils/modules.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -17,7 +17,7 @@ def __init__( hidden_features=None, out_features=None, act_layer=nn.GELU, - drop=0. 
+ drop=0.0, ): super().__init__() out_features = out_features or in_features @@ -43,14 +43,14 @@ def __init__( num_heads=8, qkv_bias=False, qk_scale=None, - attn_drop=0., - proj_drop=0., - use_sdpa=True + attn_drop=0.0, + proj_drop=0.0, + use_sdpa=True, ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) @@ -60,18 +60,24 @@ def __init__( def forward(self, x, mask=None): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D] if self.use_sdpa: with torch.backends.cuda.sdp_kernel(): - x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.proj_drop_prob + ) attn = None else: attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v) + x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) @@ -83,11 +89,11 @@ def __init__( self, dim, num_heads, - mlp_ratio=4., + mlp_ratio=4.0, qkv_bias=False, qk_scale=None, - drop=0., - attn_drop=0., + drop=0.0, + attn_drop=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, grid_size=None, @@ -101,7 +107,8 @@ def __init__( qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop) + proj_drop=drop, + ) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -109,7 +116,8 @@ def __init__( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, - drop=drop) + drop=drop, + ) def forward(self, x, return_attention=False, mask=None): y, attn = self.attn(self.norm1(x), mask=mask) @@ -121,28 +129,30 @@ def forward(self, x, return_attention=False, mask=None): class CrossAttention(nn.Module): - def __init__( - self, - dim, - num_heads=12, - qkv_bias=False, - use_sdpa=True - ): + def __init__(self, dim, num_heads=12, qkv_bias=False, use_sdpa=True): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = head_dim ** -0.5 + self.scale = head_dim**-0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias) + self.kv = nn.Linear(dim, int(dim * 2), bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_sdpa = use_sdpa def forward(self, q, x): B, n, C = q.shape - q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q = ( + self.q(q) + .reshape(B, n, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) B, N, C = x.shape - kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = ( + self.kv(x) + .reshape(B, N, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head) if self.use_sdpa: @@ -151,11 +161,11 @@ def forward(self, q, x): else: xattn = (q @ k.transpose(-2, -1)) * self.scale xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) - q = (xattn @ v) + q = xattn @ v q = q.transpose(1, 2).reshape(B, n, C) q = self.proj(q) - + return q @@ -164,17 +174,19 @@ def __init__( self, dim, num_heads, - mlp_ratio=4., + 
mlp_ratio=4.0, qkv_bias=False, act_layer=nn.GELU, - norm_layer=nn.LayerNorm + norm_layer=nn.LayerNorm, ): super().__init__() self.norm1 = norm_layer(dim) self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + self.mlp = MLP( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer + ) def forward(self, q, x): y = self.xattn(q, self.norm1(x)) diff --git a/src/models/utils/multimask.py b/src/models/utils/multimask.py index d4800869..db2f8410 100644 --- a/src/models/utils/multimask.py +++ b/src/models/utils/multimask.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the diff --git a/src/models/utils/patch_embed.py b/src/models/utils/patch_embed.py index 4ff4de51..36207155 100644 --- a/src/models/utils/patch_embed.py +++ b/src/models/utils/patch_embed.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -12,15 +12,13 @@ class PatchEmbed(nn.Module): """ Image to Patch Embedding """ - def __init__( - self, - patch_size=16, - in_chans=3, - embed_dim=768 - ): + + def __init__(self, patch_size=16, in_chans=3, embed_dim=768): super().__init__() self.patch_size = patch_size - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) def forward(self, x): B, C, H, W = x.shape @@ -43,15 +41,33 @@ def __init__( super().__init__() self.patch_size = patch_size self.tubelet_size = tubelet_size - - self.proj = nn.Conv3d( + self.proj_video = nn.Conv3d( in_channels=in_chans, out_channels=embed_dim, kernel_size=(tubelet_size, patch_size, patch_size), stride=(tubelet_size, patch_size, patch_size), ) + self.proj_image = nn.Conv2d( + in_channels=in_chans, + out_channels=embed_dim, + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + ) def forward(self, x, **kwargs): - B, C, T, H, W = x.shape - x = self.proj(x).flatten(2).transpose(1, 2) + if x is None: + return None + + if x.ndim == 5: # Video input + B, C, T, H, W = x.shape + x = self.proj_video(x) + x = x.flatten(2).transpose(1, 2) + elif x.ndim == 4: # Image input + B, C, H, W = x.shape + x = self.proj_image(x) + x = x.flatten(2).transpose(1, 2) + + else: + raise ValueError(f"Unsupported input shape: {x.shape}") + return x diff --git a/src/models/utils/pos_embs.py b/src/models/utils/pos_embs.py index d1d82e21..72be0bfb 100644 --- a/src/models/utils/pos_embs.py +++ b/src/models/utils/pos_embs.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the @@ -9,11 +9,7 @@ def get_3d_sincos_pos_embed( - embed_dim, - grid_size, - grid_depth, - cls_token=False, - uniform_power=False + embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=False ): """ grid_size: int of the grid height and width @@ -25,14 +21,16 @@ def get_3d_sincos_pos_embed( grid_d = np.arange(grid_depth, dtype=float) grid_h = np.arange(grid_size, dtype=float) grid_w = np.arange(grid_size, dtype=float) - grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] + grid_h, grid_d, grid_w = np.meshgrid( + grid_h, grid_d, grid_w + ) # order of meshgrid is very important for indexing as [d,h,w] if not uniform_power: h_embed_dim = embed_dim // 4 w_embed_dim = embed_dim // 4 d_embed_dim = embed_dim // 2 else: - h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) + h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim / 6) * 2) emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) @@ -53,7 +51,9 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_h = np.arange(grid_size, dtype=float) grid_w = np.arange(grid_size, dtype=float) - grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] + grid_w, grid_h = np.meshgrid( + grid_w, grid_h + ) # order of meshgrid is very important for indexing as [h, w] emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) @@ -86,11 +86,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=float) - omega /= embed_dim / 2. - omega = 1. / 10000**omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) - pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) diff --git a/src/models/vision_transformer.py b/src/models/vision_transformer.py index a8748dfd..7cfd3390 100644 --- a/src/models/vision_transformer.py +++ b/src/models/vision_transformer.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
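A quick sanity check of the sincos helpers reformatted above; the embedding width and grid sizes are example values only, and the expected output shapes follow from the per-axis splits shown in the hunk.

from src.models.utils.pos_embs import (
    get_2d_sincos_pos_embed,
    get_3d_sincos_pos_embed,
)

pe_2d = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)                # one row per spatial patch
pe_3d = get_3d_sincos_pos_embed(embed_dim=768, grid_size=14, grid_depth=8)  # one row per spatio-temporal patch
print(pe_2d.shape, pe_3d.shape)  # expected: (196, 768) (1568, 768)

These token counts line up with the num_patches computed in VisionTransformer below for a 224x224 input with patch size 16 and, in the video case, 16 frames with tubelet size 2.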
# # This source code is licensed under the license found in the @@ -19,7 +19,8 @@ class VisionTransformer(nn.Module): - """ Vision Transformer """ + """Vision Transformer""" + def __init__( self, img_size=224, @@ -62,7 +63,8 @@ def __init__( patch_size=patch_size, tubelet_size=tubelet_size, in_chans=in_chans, - embed_dim=embed_dim) + embed_dim=embed_dim, + ) self.num_patches = ( (num_frames // tubelet_size) * (img_size // patch_size) @@ -70,36 +72,36 @@ def __init__( ) else: self.patch_embed = PatchEmbed( - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim) - self.num_patches = ( - (img_size // patch_size) - * (img_size // patch_size) + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim ) + self.num_patches = (img_size // patch_size) * (img_size // patch_size) # Position embedding self.uniform_power = uniform_power self.pos_embed = None self.pos_embed = nn.Parameter( - torch.zeros(1, self.num_patches, embed_dim), - requires_grad=False) + torch.zeros(1, self.num_patches, embed_dim), requires_grad=False + ) # Attention Blocks - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=nn.GELU, - grid_size=grid_size, - grid_depth=grid_depth, - attn_drop=attn_drop_rate, - norm_layer=norm_layer) - for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + grid_size=grid_size, + grid_depth=grid_depth, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) self.norm = norm_layer(embed_dim) # ------ initialize weights @@ -119,7 +121,7 @@ def _init_pos_embed(self, pos_embed): grid_size, grid_depth, cls_token=False, - uniform_power=self.uniform_power + uniform_power=self.uniform_power, ) else: sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) @@ -161,6 +163,8 @@ def forward(self, x, masks=None): :param x: input image/video :param masks: indices of patch tokens to mask (remove) """ + if x is None: + raise ValueError("Input tensor x cannot be None") if masks is not None and not isinstance(masks, list): masks = [masks] @@ -169,7 +173,9 @@ def forward(self, x, masks=None): pos_embed = self.pos_embed if pos_embed is not None: pos_embed = self.interpolate_pos_encoding(x, pos_embed) + x = self.patch_embed(x) + if pos_embed is not None: x += pos_embed B, N, D = x.shape @@ -193,115 +199,146 @@ def forward(self, x, masks=None): x = self.norm(x) return x - + def interpolate_pos_encoding(self, x, pos_embed): - _, N, dim = pos_embed.shape - if self.is_video: - - # If pos_embed already corret size, just return + if x.dim() == 5: # Video clip [B, C, T, H, W] _, _, T, H, W = x.shape - if H == self.input_size and W == self.input_size and T == self.num_frames: - return pos_embed - - # Convert depth, height, width of input to be measured in patches - # instead of pixels/frames - T = T // self.tubelet_size - H = H // self.patch_size - W = W // self.patch_size - - # Compute the initialized shape of the positional embedding measured - # in patches - N_t = self.num_frames // self.tubelet_size - N_h = N_w = self.input_size // self.patch_size - assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly' - - # Compute scale factor for spatio-temporal interpolation - scale_factor = (T/N_t, H/N_h, W/N_w) - - pos_embed = nn.functional.interpolate( - 
pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), - scale_factor=scale_factor, - mode='trilinear') - pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) - return pos_embed - - else: - - # If pos_embed already corret size, just return - _, _, H, W = x.shape + # ... (rest of the video interpolation logic remains the same) + else: # Image sequence [B, T, H, W] + _, T, H, W = x.shape + # If pos_embed already correct size, just return if H == self.input_size and W == self.input_size: + # Add a temporal dimension to the positional embedding + pos_embed = pos_embed.unsqueeze(1).repeat(1, T, 1, 1) return pos_embed # Compute scale factor for spatial interpolation npatch = (H // self.patch_size) * (W // self.patch_size) + # Assuming pos_embed was initialized with no temporal dimension, + # N should correspond to the number of patches in a single image + assert N == npatch, "Input image size doesn't match model's expected size" scale_factor = math.sqrt(npatch / N) + + # Repeat positional embedding to account for the temporal dimension + pos_embed = pos_embed.unsqueeze(1).repeat(1, T, 1, 1) + # 2D interpolation of positional embeddings (spatial dimensions) pos_embed = nn.functional.interpolate( - pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), - scale_factor=scale_factor, - mode='bicubic') - pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return pos_embed + pos_embed.reshape(1, T, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 4, 1, 2, 3), + scale_factor=(1.0, scale_factor, scale_factor), # Only interpolate spatial dimensions + mode="bicubic", + align_corners=False, + ) + pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) + return pos_embed def vit_tiny(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_small(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_base(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_large(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_huge(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + 
embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_giant(patch_size=16, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + patch_size=patch_size, + embed_dim=1408, + depth=40, + num_heads=16, + mlp_ratio=48 / 11, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) return model def vit_gigantic(patch_size=14, **kwargs): model = VisionTransformer( - patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mpl_ratio=64/13, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + patch_size=patch_size, + embed_dim=1664, + depth=48, + num_heads=16, + mlp_ratio=64 / 13, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs ) return model VIT_EMBED_DIMS = { - 'vit_tiny': 192, - 'vit_small': 384, - 'vit_base': 768, - 'vit_large': 1024, - 'vit_huge': 1280, - 'vit_giant': 1408, - 'vit_gigantic': 1664, + "vit_tiny": 192, + "vit_small": 384, + "vit_base": 768, + "vit_large": 1024, + "vit_huge": 1280, + "vit_giant": 1408, + "vit_gigantic": 1664, } diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/utils/distributed.py b/src/utils/distributed.py index cfba444d..a3b7cca1 100644 --- a/src/utils/distributed.py +++ b/src/utils/distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -6,10 +6,12 @@ # import os +import traceback import torch import torch.distributed as dist + from logging import getLogger logger = getLogger() @@ -21,28 +23,26 @@ def init_distributed(port=37123, rank_and_world_size=(None, None)): return dist.get_world_size(), dist.get_rank() rank, world_size = rank_and_world_size - os.environ['MASTER_ADDR'] = 'localhost' + os.environ["MASTER_ADDR"] = "localhost" if (rank is None) or (world_size is None): try: - world_size = int(os.environ['SLURM_NTASKS']) - rank = int(os.environ['SLURM_PROCID']) - os.environ['MASTER_ADDR'] = os.environ['HOSTNAME'] + world_size = int(os.environ["SLURM_NTASKS"]) + rank = int(os.environ["SLURM_PROCID"]) + os.environ["MASTER_ADDR"] = os.environ["HOSTNAME"] except Exception: - logger.info('SLURM vars not set (distributed training not available)') + logger.info("SLURM vars not set (distributed training not available)") world_size, rank = 1, 0 return world_size, rank try: - os.environ['MASTER_PORT'] = str(port) + os.environ["MASTER_PORT"] = str(port) torch.distributed.init_process_group( - backend='nccl', - world_size=world_size, - rank=rank + backend="nccl", world_size=world_size, rank=rank ) except Exception as e: world_size, rank = 1, 0 - logger.info(f'Rank: {rank}. Distributed training not available {e}') + logger.info(f"Rank: {rank}. Distributed training not available {traceback.format_exc()}") return world_size, rank diff --git a/src/utils/logging.py b/src/utils/logging.py index fcdd3faf..f8039101 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved.
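A minimal sketch of how init_distributed, refactored above, can be exercised without SLURM: rank and world size are passed explicitly, so the SLURM environment lookup is skipped, and on a machine where NCCL is unavailable the helper falls back to a single-process setup of (world_size, rank) = (1, 0). The values here are illustrative.

from src.utils.distributed import init_distributed

# Single-process run: rank_and_world_size is ordered (rank, world_size).
world_size, rank = init_distributed(port=37123, rank_and_world_size=(0, 1))
print(f"initialized rank {rank} of {world_size}")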
# # This source code is licensed under the license found in the @@ -12,10 +12,10 @@ def gpu_timer(closure, log_timings=True): - """ Helper to time gpu-time to execute closure() """ + """Helper to time gpu-time to execute closure()""" log_timings = log_timings and torch.cuda.is_available() - elapsed_time = -1. + elapsed_time = -1.0 if log_timings: start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) @@ -36,8 +36,13 @@ def gpu_timer(closure, log_timings=True): def get_logger(name=None, force=False): - logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format=LOG_FORMAT, + datefmt=DATE_FORMAT, + force=force, + ) return logging.getLogger(name=name) @@ -47,18 +52,18 @@ def __init__(self, fname, *argv): self.fname = fname self.types = [] # -- print headers - with open(self.fname, '+a') as f: + with open(self.fname, "+a") as f: for i, v in enumerate(argv, 1): self.types.append(v[0]) if i < len(argv): - print(v[1], end=',', file=f) + print(v[1], end=",", file=f) else: - print(v[1], end='\n', file=f) + print(v[1], end="\n", file=f) def log(self, *argv): - with open(self.fname, '+a') as f: + with open(self.fname, "+a") as f: for i, tv in enumerate(zip(self.types, argv), 1): - end = ',' if i < len(argv) else '\n' + end = "," if i < len(argv) else "\n" print(tv[0] % tv[1], end=end, file=f) @@ -71,8 +76,8 @@ def __init__(self): def reset(self): self.val = 0 self.avg = 0 - self.max = float('-inf') - self.min = float('inf') + self.max = float("-inf") + self.min = float("inf") self.sum = 0 self.count = 0 @@ -93,26 +98,26 @@ def grad_logger(named_params): stats.first_layer = None stats.last_layer = None for n, p in named_params: - if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): + if (p.grad is not None) and not (n.endswith(".bias") or len(p.shape) == 1): grad_norm = float(torch.norm(p.grad.data)) stats.update(grad_norm) - if 'qkv' in n: + if "qkv" in n: stats.last_layer = grad_norm if stats.first_layer is None: stats.first_layer = grad_norm if stats.first_layer is None or stats.last_layer is None: - stats.first_layer = stats.last_layer = 0. + stats.first_layer = stats.last_layer = 0.0 return stats def adamw_logger(optimizer): - """ logging magnitude of first and second momentum buffers in adamw """ + """logging magnitude of first and second momentum buffers in adamw""" # TODO: assert that optimizer is instance of torch.optim.AdamW - state = optimizer.state_dict().get('state') + state = optimizer.state_dict().get("state") exp_avg_stats = AverageMeter() exp_avg_sq_stats = AverageMeter() for key in state: s = state.get(key) - exp_avg_stats.update(float(s.get('exp_avg').abs().mean())) - exp_avg_sq_stats.update(float(s.get('exp_avg_sq').abs().mean())) - return {'exp_avg': exp_avg_stats, 'exp_avg_sq': exp_avg_sq_stats} + exp_avg_stats.update(float(s.get("exp_avg").abs().mean())) + exp_avg_sq_stats.update(float(s.get("exp_avg_sq").abs().mean())) + return {"exp_avg": exp_avg_stats, "exp_avg_sq": exp_avg_sq_stats} diff --git a/src/utils/monitoring.py b/src/utils/monitoring.py index 95a7845a..505a2a74 100644 --- a/src/utils/monitoring.py +++ b/src/utils/monitoring.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. 
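An illustrative use of the logging helpers reformatted above. The CSV logger's class name (CSVLogger) is assumed from context, since only its methods appear in this hunk; the file path and column specs are examples. Each column is declared as a (printf-style format, header) pair and later filled positionally by log().

from src.utils.logging import AverageMeter, CSVLogger  # CSVLogger name assumed

csv_logger = CSVLogger(
    "/tmp/train_stats.csv",  # example path; headers are written on construction
    ("%d", "epoch"),
    ("%d", "itr"),
    ("%.5f", "loss"),
)

loss_meter = AverageMeter()
loss_meter.update(0.731)
loss_meter.update(0.702)

csv_logger.log(1, 100, loss_meter.avg)  # appends a row like: 1,100,0.71650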
# # This source code is licensed under the license found in the @@ -56,11 +56,12 @@ def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None): if stats_callback_fn is None: # Default callback def stats_callback_fn(resource_sample: ResourceStatsSample): - print( - f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") + print(f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") + elif not callable(stats_callback_fn): - raise ValueError("Callback needs to be callable, got {}".format( - type(stats_callback_fn))) + raise ValueError( + "Callback needs to be callable, got {}".format(type(stats_callback_fn)) + ) self.stats_callback_fn = stats_callback_fn def stop(self) -> None: @@ -121,8 +122,7 @@ def compress_cpu_affinity(cpu_affinity): if min_x == max_x: cpu_affinity_compressed.append("{}".format(min_x)) else: - cpu_affinity_compressed.append( - "{}-{}".format(min_x, max_x)) + cpu_affinity_compressed.append("{}-{}".format(min_x, max_x)) min_x = x max_x = x last_x = x @@ -131,8 +131,7 @@ def compress_cpu_affinity(cpu_affinity): if min_x == max_x: cpu_affinity_compressed.append("{}".format(min_x)) else: - cpu_affinity_compressed.append( - "{}-{}".format(min_x, max_x)) + cpu_affinity_compressed.append("{}-{}".format(min_x, max_x)) # Concat cpu_affinity_compressed = ",".join(cpu_affinity_compressed) @@ -167,6 +166,7 @@ def compress_cpu_affinity(cpu_affinity): if __name__ == "__main__": import multiprocessing import time + pid = multiprocessing.current_process().pid monitor_thread = ResourceMonitoringThread(pid, 1) monitor_thread.start() diff --git a/src/utils/schedulers.py b/src/utils/schedulers.py index df02e2b0..b3496d57 100644 --- a/src/utils/schedulers.py +++ b/src/utils/schedulers.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -18,7 +18,7 @@ def __init__( ref_lr, T_max, last_epoch=-1, - final_lr=0. + final_lr=0.0, ): self.optimizer = optimizer self.start_lr = start_lr @@ -26,7 +26,7 @@ def __init__( self.final_lr = final_lr self.warmup_steps = warmup_steps self.T_max = T_max - warmup_steps - self._step = 0. + self._step = 0.0 def step(self): self._step += 1 @@ -36,34 +36,35 @@ def step(self): else: # -- progress after warmup progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) - new_lr = max(self.final_lr, - self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress))) + new_lr = max( + self.final_lr, + self.final_lr + + (self.ref_lr - self.final_lr) + * 0.5 + * (1.0 + math.cos(math.pi * progress)), + ) for group in self.optimizer.param_groups: - group['lr'] = new_lr + group["lr"] = new_lr return new_lr class CosineWDSchedule(object): - def __init__( - self, - optimizer, - ref_wd, - T_max, - final_wd=0. - ): + def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0): self.optimizer = optimizer self.ref_wd = ref_wd self.final_wd = final_wd self.T_max = T_max - self._step = 0. + self._step = 0.0 def step(self): self._step += 1 progress = self._step / self.T_max - new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. 
+ math.cos(math.pi * progress)) + new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * ( + 1.0 + math.cos(math.pi * progress) + ) if self.final_wd <= self.ref_wd: new_wd = max(self.final_wd, new_wd) @@ -71,6 +72,6 @@ def step(self): new_wd = min(self.final_wd, new_wd) for group in self.optimizer.param_groups: - if ('WD_exclude' not in group) or not group['WD_exclude']: - group['weight_decay'] = new_wd + if ("WD_exclude" not in group) or not group["WD_exclude"]: + group["weight_decay"] = new_wd return new_wd diff --git a/src/utils/tensors.py b/src/utils/tensors.py index 6ae28509..14369a5b 100644 --- a/src/utils/tensors.py +++ b/src/utils/tensors.py @@ -1,4 +1,4 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) NeoCybernetica, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -14,12 +14,23 @@ logger = getLogger() +def to_batch(images): + """Converts a list of images into a batched tensor. + + Args: + images (list): A list of image tensors, each of shape [C, H, W]. + + Returns: + torch.Tensor: A batched tensor of shape [B, C, H, W], where B is the batch size. + """ + return torch.stack(images, dim=0) + def _no_grad_trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 with torch.no_grad(): # Values are generated by using a truncated uniform distribution and @@ -37,7 +48,7 @@ def norm_cdf(x): tensor.erfinv_() # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) + tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range @@ -45,7 +56,7 @@ def norm_cdf(x): return tensor -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): # type: (Tensor, float, float, float, float) -> Tensor return _no_grad_trunc_normal_(tensor, mean, std, a, b) @@ -64,8 +75,11 @@ def apply_masks(x, masks): def repeat_interleave_batch(x, B, repeat): N = len(x) // B - x = torch.cat([ - torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) - for i in range(N) - ], dim=0) + x = torch.cat( + [ + torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) + for i in range(N) + ], + dim=0, + ) return x
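Finally, a usage sketch for the schedules reformatted above. The warmup scheduler's class name (WarmupCosineSchedule) is assumed from context; the keyword arguments follow the constructor bodies shown in this diff, and the optimizer, step counts, and rates are example values only. Both schedules are stepped once per optimization step.

import torch
from src.utils.schedulers import CosineWDSchedule, WarmupCosineSchedule  # warmup class name assumed

params = [torch.nn.Parameter(torch.randn(8, 8))]
optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.04)

total_steps, warmup_steps = 10_000, 500
lr_schedule = WarmupCosineSchedule(
    optimizer,
    warmup_steps=warmup_steps,
    start_lr=1e-6,
    ref_lr=1e-4,
    final_lr=1e-6,
    T_max=total_steps,
)
wd_schedule = CosineWDSchedule(optimizer, ref_wd=0.04, T_max=total_steps, final_wd=0.4)

for _ in range(total_steps):
    lr = lr_schedule.step()  # linear warmup to ref_lr, then cosine decay toward final_lr
    wd = wd_schedule.step()  # cosine ramp of weight decay toward final_wd
    # forward / backward / optimizer.step() would go here

Parameter groups flagged with "WD_exclude" keep their original weight decay, which is how bias and norm parameters are typically exempted from the schedule.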