diff --git a/README.md b/README.md index a3579e10..d985dcd1 100644 --- a/README.md +++ b/README.md @@ -352,32 +352,32 @@ python -m app.main_distributed \ ## Launching Evaluations ### Local training If you wish to debug your eval code or setup before launching a distributed training run, we provide the functionality to do so by running the evaluation script locally on a multi-GPU (or single-GPU) machine, however, reproducing the full eval would require launching distributed training. -The single-machine implementation starts from the [eval/main.py](eval/main.py), which parses the experiment config file and runs the eval locally on a multi-GPU (or single-GPU) machine. +The single-machine implementation starts from the [evals/main.py](evals/main.py), which parses the experiment config file and runs the eval locally on a multi-GPU (or single-GPU) machine. -For example, to run ImageNet image classification on GPUs "0", "1", and "2" on a local machine using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: +For example, to run ImageNet image classification on GPUs "0", "1", and "2" on a local machine using the config [configs/evals/vitl16_in1k.yaml](configs/evals/vitl16_in1k.yaml), type the command: ```bash python -m evals.main \ - --fname configs/eval/vitl16_in1k.yaml \ + --fname configs/evals/vitl16_in1k.yaml \ --devices cuda:0 cuda:1 cuda:2 ``` ### Distributed training -To launch a distributed evaluation run, the implementation starts from [eval/main_distributed.py](eval/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster. +To launch a distributed evaluation run, the implementation starts from [evals/main_distributed.py](evals/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster. 
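Under the hood, [evals/main_distributed.py](evals/main_distributed.py) configures a submitit executor with the SLURM resources and submits one job per parsed config. A minimal sketch of that pattern (the log folder, timeout, and resource numbers below are illustrative placeholders, not the launcher's defaults):

```python
import submitit

def run_eval():
    # placeholder for the real entry point that parses the config and runs the eval
    print("running eval on the allocated SLURM node")

executor = submitit.AutoExecutor(folder="submitit_logs/job_%j")
executor.update_parameters(
    timeout_min=60,        # submitit expects minutes here
    nodes=1,
    tasks_per_node=1,
    gpus_per_node=1,
    cpus_per_task=8,
    slurm_mem='128G',
)
job = executor.submit(run_eval)
print(job.job_id)
```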
-For example, to launch a distributed ImageNet image classification experiment using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: +For example, to launch a distributed ImageNet image classification experiment using the config [configs/evals/vitl16_in1k.yaml](configs/evals/vitl16_in1k.yaml), type the command: ```bash python -m evals.main_distributed \ - --fname configs/eval/vitl16_in1k.yaml \ + --fname configs/evals/vitl16_in1k.yaml \ --folder $path_to_save_stderr_and_stdout \ --partition $slurm_partition ``` -Similarly, to launch a distributed K400 video classification experiment using the config [configs/eval/vitl16_k400.yaml](configs/eval/vitl16_k400.yaml), type the command: +Similarly, to launch a distributed K400 video classification experiment using the config [configs/evals/vitl16_k400_16x8x3.yaml](configs/evals/vitl16_k400_16x8x3.yaml), type the command: ```bash python -m evals.main_distributed \ - --fname configs/eval/vitl16_k400.yaml \ + --fname configs/evals/vitl16_k400_16x8x3.yaml \ --folder $path_to_save_stderr_and_stdout \ --partition $slurm_partition ``` diff --git a/configs/evals/vitl16_k400_16x8x3.yaml b/configs/evals/vitl16_k400_16x8x3.yaml index b7bcf052..0bf44071 100644 --- a/configs/evals/vitl16_k400_16x8x3.yaml +++ b/configs/evals/vitl16_k400_16x8x3.yaml @@ -1,20 +1,21 @@ -nodes: 8 -tasks_per_node: 8 -tag: k400-16x8x3 + +nodes: 1 +tasks_per_node: 1 +tag: pegs-probe-300 eval_name: video_classification_frozen resume_checkpoint: false data: - dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv - dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv + dataset_train: /scratch/ki2130/my-jepa/jepa/p_dummy.csv + dataset_val: /scratch/ki2130/my-jepa/jepa/p_dummy.csv dataset_type: VideoDataset - num_classes: 400 + num_classes: 165 frames_per_clip: 16 num_segments: 8 num_views_per_segment: 3 frame_step: 4 optimization: attend_across_segments: true - num_epochs: 20 + num_epochs: 300 resolution: 224 batch_size: 4 weight_decay: 0.01 @@ -34,6 +35,6 @@ pretrain: tight_silu: false use_sdpa: true patch_size: 16 - folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ - checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder + folder: /scratch/ki2130/ #/your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ + checkpoint: vitl16.pth.tar #jepa-latest.pth.tar # name of pretrained model file inside folder write_tag: jepa diff --git a/evals/main_distributed.py b/evals/main_distributed.py index 1f332a0b..389cf8f6 100644 --- a/evals/main_distributed.py +++ b/evals/main_distributed.py @@ -73,8 +73,8 @@ def checkpoint(self): def launch_evals_with_parsed_args( args_for_evals, submitit_folder, - partition='learnlab,learnfair', - timeout=4300, + # partition='a100_2', + timeout="48:00:00", nodes=1, tasks_per_node=1, delay_seconds=10, @@ -90,13 +90,17 @@ def launch_evals_with_parsed_args( folder=os.path.join(submitit_folder, 'job_%j'), slurm_max_num_timeout=20) executor.update_parameters( - slurm_partition=partition, - slurm_mem_per_gpu='55G', + # slurm_partition=partition, + slurm_mem='128G', timeout_min=timeout, nodes=nodes, tasks_per_node=tasks_per_node, - cpus_per_task=12, - gpus_per_node=tasks_per_node) + cpus_per_task=8, + gpus_per_node=1, + slurm_mail_type='ALL', + slurm_mail_user='ki2130@nyu.edu', + slurm_job_name='model-jepa2') + # slurm_additional_parameters={'gres': 'gpu:a100:1'}) if exclude_nodes is not None: 
executor.update_parameters(slurm_exclude=exclude_nodes) @@ -149,7 +153,7 @@ def launch_evals(): launch_evals_with_parsed_args( args_for_evals=configs, submitit_folder=args.folder, - partition=args.partition, + # partition=args.partition, timeout=args.time, nodes=nodes, tasks_per_node=tasks_per_node, @@ -160,3 +164,5 @@ def launch_evals(): if __name__ == '__main__': args = parser.parse_args() launch_evals() + print("made it!") + diff --git a/evals/scaffold.py b/evals/scaffold.py index c816b874..b2f10404 100644 --- a/evals/scaffold.py +++ b/evals/scaffold.py @@ -19,6 +19,6 @@ def main( resume_preempt=False ): logger.info(f'Running evaluation: {eval_name}') - return importlib.import_module(f'evals.{eval_name}.eval').main( + return importlib.import_module(f'evals.{eval_name}.pegs_eval').main( args_eval=args_eval, resume_preempt=resume_preempt) diff --git a/evals/video_classification_frozen/kp_utils.py b/evals/video_classification_frozen/kp_utils.py new file mode 100644 index 00000000..21dcd118 --- /dev/null +++ b/evals/video_classification_frozen/kp_utils.py @@ -0,0 +1,30 @@ +import os +import matplotlib.pyplot as plt +from PIL import Image + +def plot_guess_img(input_tensor, output_filename, reference_img='evals/video_classification_frozen/reference_with_border.png', scale=2000): + assert input_tensor.numel() == 165, "input tensor must have exactly 165 entries" + background = Image.open(reference_img) + + plt.figure(figsize=(background.width / 100, background.height / 100), dpi=100) + plt.imshow(background) + plt.axis('off') + + start_x, start_y = 520, 228 + spacing = 63 + scatter_x = [] + scatter_y = [] + sizes = [] + + for idx in range(input_tensor.numel()): + row = idx // 15 + col = idx % 15 + x = start_x + col * spacing + y = start_y + row * spacing + scatter_x.append(x) + scatter_y.append(y) + sizes.append(input_tensor[idx].item() * scale) + + plt.scatter(scatter_x, scatter_y, s=sizes, c='green', alpha=0.4) + plt.savefig(os.path.join('plots/plots-300/', output_filename), bbox_inches='tight', pad_inches=0) + plt.close() diff --git a/evals/video_classification_frozen/pegs_eval.py b/evals/video_classification_frozen/pegs_eval.py new file mode 100644 index 00000000..f48749e2 --- /dev/null +++ b/evals/video_classification_frozen/pegs_eval.py @@ -0,0 +1,611 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import os + +# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS +try: + # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE + # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE + # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE + # -- TO EACH PROCESS + os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] +except Exception: + pass + +import logging +import pprint + +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F + +from torch.nn.parallel import DistributedDataParallel + +import src.models.vision_transformer as vit +from src.models.pegs_attentive_probe import PegAttentiveClassifier +from src.datasets.data_manager import ( + init_data, +) +from src.utils.distributed import ( + init_distributed, + AllReduce +) +from src.utils.schedulers import ( + WarmupCosineSchedule, + CosineWDSchedule, +) +from src.utils.logging import ( + AverageMeter, + CSVLogger +) + +from evals.video_classification_frozen.utils import ( + make_transforms, + ClipAggregation, + FrameAggregation +) + +from evals.video_classification_frozen.kp_utils import plot_guess_img + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + +pp = pprint.PrettyPrinter(indent=4) + + +def main(args_eval, resume_preempt=False): + + # ----------------------------------------------------------------------- # + # PASSED IN PARAMS FROM CONFIG FILE + # ----------------------------------------------------------------------- # + + # -- PRETRAIN + args_pretrain = args_eval.get('pretrain') + checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder') + model_name = args_pretrain.get('model_name', None) + patch_size = args_pretrain.get('patch_size', None) + pretrain_folder = args_pretrain.get('folder', None) + ckp_fname = args_pretrain.get('checkpoint', None) + tag = args_pretrain.get('write_tag', None) + use_sdpa = args_pretrain.get('use_sdpa', True) + use_SiLU = args_pretrain.get('use_silu', False) + tight_SiLU = args_pretrain.get('tight_silu', True) + uniform_power = args_pretrain.get('uniform_power', False) + pretrained_path = os.path.join(pretrain_folder, ckp_fname) + # Optional [for Video model]: + tubelet_size = args_pretrain.get('tubelet_size', 2) + pretrain_frames_per_clip = args_pretrain.get('frames_per_clip', 1) + + # -- DATA + args_data = args_eval.get('data') + train_data_path = [args_data.get('dataset_train')] + val_data_path = [args_data.get('dataset_val')] + dataset_type = args_data.get('dataset_type', 'VideoDataset') + num_classes = args_data.get('num_classes') + eval_num_segments = args_data.get('num_segments', 1) + eval_frames_per_clip = args_data.get('frames_per_clip', 16) + eval_frame_step = args_pretrain.get('frame_step', 4) + eval_duration = args_pretrain.get('clip_duration', None) + eval_num_views_per_segment = args_data.get('num_views_per_segment', 1) + + # -- OPTIMIZATION + args_opt = args_eval.get('optimization') + resolution = args_opt.get('resolution', 224) + batch_size = args_opt.get('batch_size') + attend_across_segments = args_opt.get('attend_across_segments', False) + num_epochs = args_opt.get('num_epochs') + wd = args_opt.get('weight_decay') + start_lr = args_opt.get('start_lr') + lr = args_opt.get('lr') + final_lr = args_opt.get('final_lr') + warmup = args_opt.get('warmup') + use_bfloat16 = args_opt.get('use_bfloat16') + 
+ # -- EXPERIMENT-ID/TAG (optional) + resume_checkpoint = args_eval.get('resume_checkpoint', False) or resume_preempt + eval_tag = args_eval.get('tag', None) + + # ----------------------------------------------------------------------- # + + try: + mp.set_start_method('spawn') + except Exception: + pass + + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + world_size, rank = init_distributed() + logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + + # -- log/checkpointing paths + folder = os.path.join(pretrain_folder, 'my-jepa/jepa/' , 'checkpoint-models/') + if eval_tag is not None: + folder = os.path.join(folder, eval_tag) + if not os.path.exists(folder): + os.makedirs(folder, exist_ok=True) + log_file = os.path.join(folder, f'{tag}_r{rank}.csv') + latest_path = os.path.join(folder, f'{tag}-latest.pth.tar') + + # -- make csv_logger + if rank == 0: + csv_logger = CSVLogger(log_file, + ('%d', 'epoch'), + ('%.5f', 'loss'), + ('%.5f', 'acc')) + + # Initialize model + + # -- pretrained encoder (frozen) + encoder = init_model( + crop_size=resolution, + device=device, + pretrained=pretrained_path, + model_name=model_name, + patch_size=patch_size, + tubelet_size=tubelet_size, + frames_per_clip=pretrain_frames_per_clip, + uniform_power=uniform_power, + checkpoint_key=checkpoint_key, + use_SiLU=use_SiLU, + tight_SiLU=tight_SiLU, + use_sdpa=use_sdpa) + if pretrain_frames_per_clip == 1: + # Process each frame independently and aggregate + encoder = FrameAggregation(encoder).to(device) + else: + # Process each video clip independently and aggregate + encoder = ClipAggregation( + encoder, + tubelet_size=tubelet_size, + attend_across_segments=attend_across_segments + ).to(device) + encoder.eval() + for p in encoder.parameters(): + p.requires_grad = False + + # -- init classifier + classifier = PegAttentiveClassifier( + embed_dim=encoder.embed_dim, + num_classes=num_classes, + ).to(device) + + train_loader = make_dataloader( + dataset_type=dataset_type, + root_path=train_data_path, + resolution=resolution, + frames_per_clip=eval_frames_per_clip, + frame_step=eval_frame_step, + eval_duration=eval_duration, + num_segments=eval_num_segments if attend_across_segments else 1, + num_views_per_segment=1, + allow_segment_overlap=True, + batch_size=batch_size, + world_size=world_size, + rank=rank, + training=True) + val_loader = make_dataloader( + dataset_type=dataset_type, + root_path=val_data_path, + resolution=resolution, + frames_per_clip=eval_frames_per_clip, + frame_step=eval_frame_step, + num_segments=eval_num_segments, + eval_duration=eval_duration, + num_views_per_segment=eval_num_views_per_segment, + allow_segment_overlap=True, + batch_size=batch_size, + world_size=world_size, + rank=rank, + training=False) + ipe = len(train_loader) + logger.info(f'Dataloader created... 
iterations per epoch: {ipe}') + + # -- optimizer and scheduler + optimizer, scaler, scheduler, wd_scheduler = init_opt( + classifier=classifier, + wd=wd, + start_lr=start_lr, + ref_lr=lr, + final_lr=final_lr, + iterations_per_epoch=ipe, + warmup=warmup, + num_epochs=num_epochs, + use_bfloat16=use_bfloat16) + classifier = DistributedDataParallel(classifier, static_graph=True) + + # -- load training checkpoint + start_epoch = 0 + if resume_checkpoint: + classifier, optimizer, scaler, start_epoch = load_checkpoint( + device=device, + r_path=latest_path, + classifier=classifier, + opt=optimizer, + scaler=scaler) + for _ in range(start_epoch*ipe): + scheduler.step() + wd_scheduler.step() + + def save_checkpoint(epoch): + save_dict = { + 'classifier': classifier.state_dict(), + 'opt': optimizer.state_dict(), + 'scaler': None if scaler is None else scaler.state_dict(), + 'epoch': epoch, + 'batch_size': batch_size, + 'world_size': world_size, + 'lr': lr + } + if rank == 0: + torch.save(save_dict, latest_path) + + # Arrays accumulating per-iteration losses across all epochs + training_losses = np.array([]) + testing_losses = np.array([]) + + # TRAIN LOOP + for epoch in range(start_epoch, num_epochs): + logger.info('Epoch %d' % (epoch + 1)) + training_losses = run_one_epoch( + device=device, + training=True, + num_temporal_views=eval_num_segments if attend_across_segments else 1, + attend_across_segments=attend_across_segments, + num_spatial_views=1, + encoder=encoder, + classifier=classifier, + scaler=scaler, + optimizer=optimizer, + scheduler=scheduler, + wd_scheduler=wd_scheduler, + data_loader=train_loader, + use_bfloat16=use_bfloat16, + loss_array=training_losses) # pass the training_losses array + + testing_losses = run_one_epoch( + device=device, + training=False, + num_temporal_views=eval_num_segments, + attend_across_segments=attend_across_segments, + num_spatial_views=eval_num_views_per_segment, + encoder=encoder, + classifier=classifier, + scaler=scaler, + optimizer=optimizer, + scheduler=scheduler, + wd_scheduler=wd_scheduler, + data_loader=val_loader, + use_bfloat16=use_bfloat16, + loss_array=testing_losses) # pass the testing_losses array + + train_loss = float(np.mean(training_losses[-len(train_loader):])) + val_loss = float(np.mean(testing_losses[-len(val_loader):])) + logger.info('[%5d] train loss: %.3f test loss: %.3f' % (epoch + 1, train_loss, val_loss)) + if rank == 0: + csv_logger.log(epoch + 1, train_loss, val_loss) + save_checkpoint(epoch + 1) + + # Save the numpy arrays after all epochs + np.save('training_losses.npy', training_losses) + np.save('testing_losses.npy', testing_losses) + + +def run_one_epoch( + device, + training, + encoder, + classifier, + scaler, + optimizer, + scheduler, + wd_scheduler, + data_loader, + use_bfloat16, + num_spatial_views, + num_temporal_views, + attend_across_segments, + loss_array, # accumulates the per-iteration losses +): + + classifier.train(mode=training) + criterion = torch.nn.BCEWithLogitsLoss() + top1_meter = AverageMeter() + for itr, data in enumerate(data_loader): + + if training: + scheduler.step() + wd_scheduler.step() + + with torch.cuda.amp.autocast(dtype=torch.float16, enabled=use_bfloat16): + + # Load data and put on GPU + clips = [ + [dij.to(device, non_blocking=True) for dij in di] # iterate over spatial views of clip + for di in data[0] # iterate over temporal index of clip + ] + clip_indices = [d.to(device, non_blocking=True) for d in data[2]] + labels = data[1].to(device) + batch_size = len(labels) + + print("is it training?", training) + + # Forward and prediction + with torch.no_grad(): + outputs = 
encoder(clips, clip_indices) + # print("outputs with encoder applied to clips and clips indices shape:", outputs[0].shape) + if not training: + if attend_across_segments: + outputs = [classifier(o) for o in outputs] + #print("outputs shape", outputs[0].shape) + else: + outputs = [[classifier(ost) for ost in os] for os in outputs] + #print("ouputs shape", outputs[0].shape) + if training: + if attend_across_segments: + outputs = [classifier(o) for o in outputs] + # print("outputs attend:", outputs[0].shape) + else: + outputs = [[classifier(ost) for ost in os] for os in outputs] + # print("outputs NOT attend:", outputs[0].shape) + # print("outputs is:", outputs) + # print("outputs[0] final shape", outputs[0].shape) + # print("labels[0] final shape:", labels[0].shape) + + # save output and label as images (comment this out when done testing) + # plot_guess_img(outputs[0][0,:], output_filename = 'outputs-0.png') + + sigmoid = torch.nn.Sigmoid() + sigmoid_outputs = sigmoid(outputs[0][0,:]) + sigmoid_outputs = sigmoid_outputs.squeeze(0) + # print("sigmoid outouts:", sigmoid_outputs) + # print("sigmoid_outputs shape:", sigmoid_outputs.shape) + plot_guess_img(sigmoid_outputs, output_filename = 'outputs-0.png') + plot_guess_img(labels[0], output_filename = 'labels-0.png') + + # print("PRINT outputs[0][0:]", outputs[0][0:]) + # print("PRINT 'labels[0]",labels[0]) + # print("PRINT length labels", len(labels)) + + # Compute loss + if attend_across_segments: + loss = 0 + for i in range(len(labels)): + # print("PRINT outputs[0][0:]", outputs[0][0:]) + # print("PRINT 'labels[0]",labels[0]) + # print("PRINT length labels", len(labels)) + loss+=criterion(outputs[0][i,:], labels[i].unsqueeze(0)) + loss = loss/len(labels) + # loss = sum([criterion(o, labels) for o in outputs]) / len(outputs) + #else: + #loss = sum([sum([criterion(ost, labels) for ost in os]) for os in outputs]) / len(outputs) / len(outputs[0]) + with torch.no_grad(): + # kat + if attend_across_segments: + sum_softmax = 0 + # print("output[0].shape[0]",outputs[0].shape[0]) + for i in range(outputs[0].shape[0]): + # print("PRINT outputs[0][0:]", outputs[0][0:]) + # print("PRINT 'labels[0]",labels[0]) + # print("PRINT length labels", len(labels)) + sum_softmax += F.softmax(outputs[0][i,:], dim=0) # no averaging (dividing by len(outputs)) + outputs = sum_softmax + + if training: + if use_bfloat16: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0) + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0) + optimizer.step() + optimizer.zero_grad() + + # Append loss to the appropriate numpy array + loss_array = np.append(loss_array, loss.item()) + + if itr % 20 == 0: + logger.info('[%5d] (loss: %.3f) [mem: %.2e]' + % (itr, loss, + torch.cuda.max_memory_allocated() / 1024.**2)) + + return loss_array + + + +def load_checkpoint( + device, + r_path, + classifier, + opt, + scaler +): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + epoch = checkpoint['epoch'] + + # -- loading encoder + pretrained_dict = checkpoint['classifier'] + msg = classifier.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained classifier from epoch {epoch} with msg: {msg}') + + # -- loading optimizer + opt.load_state_dict(checkpoint['opt']) + if scaler is not None: + scaler.load_state_dict(checkpoint['scaler']) + logger.info(f'loaded optimizers from epoch {epoch}') + logger.info(f'read-path: 
{r_path}') + del checkpoint + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return classifier, opt, scaler, epoch + + +def load_pretrained( + encoder, + pretrained, + checkpoint_key='target_encoder' +): + logger.info(f'Loading pretrained model from {pretrained}') + checkpoint = torch.load(pretrained, map_location='cpu') + try: + pretrained_dict = checkpoint[checkpoint_key] + except Exception: + pretrained_dict = checkpoint['encoder'] + + pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()} + pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()} + for k, v in encoder.state_dict().items(): + if k not in pretrained_dict: + logger.info(f'key "{k}" could not be found in loaded state dict') + elif pretrained_dict[k].shape != v.shape: + logger.info(f'key "{k}" is of different shape in model and loaded state dict') + pretrained_dict[k] = v + msg = encoder.load_state_dict(pretrained_dict, strict=False) + print(encoder) + logger.info(f'loaded pretrained model with msg: {msg}') + logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}') + del checkpoint + return encoder + + +def make_dataloader( + root_path, + batch_size, + world_size, + rank, + dataset_type='VideoDataset', + resolution=224, + frames_per_clip=16, + frame_step=4, + num_segments=8, + eval_duration=None, + num_views_per_segment=1, + allow_segment_overlap=True, + training=False, + num_workers=8, + subset_file=None +): + # Make Video Transforms + transform = make_transforms( + training=training, + num_views_per_clip=num_views_per_segment, + random_horizontal_flip=False, + random_resize_aspect_ratio=(0.75, 4/3), + random_resize_scale=(0.08, 1.0), + reprob=0.25, + auto_augment=True, + motion_shift=False, + crop_size=resolution, + ) + + data_loader, _ = init_data( + data=dataset_type, + root_path=root_path, + transform=transform, + batch_size=batch_size, + world_size=world_size, + rank=rank, + clip_len=frames_per_clip, + frame_sample_rate=frame_step, + duration=eval_duration, + num_clips=num_segments, + allow_clip_overlap=allow_segment_overlap, + num_workers=num_workers, + copy_data=False, + drop_last=False, + subset_file=subset_file) + return data_loader + + +def init_model( + device, + pretrained, + model_name, + patch_size=16, + crop_size=224, + # Video specific parameters + frames_per_clip=16, + tubelet_size=2, + use_sdpa=False, + use_SiLU=False, + tight_SiLU=True, + uniform_power=False, + checkpoint_key='target_encoder' +): + encoder = vit.__dict__[model_name]( + img_size=crop_size, + patch_size=patch_size, + num_frames=frames_per_clip, + tubelet_size=tubelet_size, + uniform_power=uniform_power, + use_sdpa=use_sdpa, + use_SiLU=use_SiLU, + tight_SiLU=tight_SiLU, + ) + + encoder.to(device) + encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key) + return encoder + + +def init_opt( + classifier, + iterations_per_epoch, + start_lr, + ref_lr, + warmup, + num_epochs, + wd=1e-6, + final_wd=1e-6, + final_lr=0.0, + use_bfloat16=False +): + param_groups = [ + { + 'params': (p for n, p in classifier.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in classifier.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': True, + 'weight_decay': 0 + } + ] + + logger.info('Using AdamW') + optimizer = torch.optim.AdamW(param_groups) + scheduler = WarmupCosineSchedule( + optimizer, 
+ warmup_steps=int(warmup*iterations_per_epoch), + start_lr=start_lr, + ref_lr=ref_lr, + final_lr=final_lr, + T_max=int(num_epochs*iterations_per_epoch)) + wd_scheduler = CosineWDSchedule( + optimizer, + ref_wd=wd, + final_wd=final_wd, + T_max=int(num_epochs*iterations_per_epoch)) + scaler = torch.cuda.amp.GradScaler() if use_bfloat16 else None + return optimizer, scaler, scheduler, wd_scheduler diff --git a/evals/video_classification_frozen/pegs_eval2.py b/evals/video_classification_frozen/pegs_eval2.py new file mode 100644 index 00000000..89b92816 --- /dev/null +++ b/evals/video_classification_frozen/pegs_eval2.py @@ -0,0 +1,588 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os + +# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS +try: + # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE + # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE + # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE + # -- TO EACH PROCESS + os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] +except Exception: + pass + +import logging +import pprint + +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F + +from torch.nn.parallel import DistributedDataParallel + +import src.models.vision_transformer as vit +from src.models.pegs_attentive_probe import PegAttentiveClassifier +from src.datasets.data_manager import ( + init_data, +) +from src.utils.distributed import ( + init_distributed, + AllReduce +) +from src.utils.schedulers import ( + WarmupCosineSchedule, + CosineWDSchedule, +) +from src.utils.logging import ( + AverageMeter, + CSVLogger +) + +from evals.video_classification_frozen.utils import ( + make_transforms, + ClipAggregation, + FrameAggregation +) + +from evals.video_classification_frozen.kp_utils import plot_guess_img + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + +pp = pprint.PrettyPrinter(indent=4) + + +def main(args_eval, resume_preempt=False): + + # ----------------------------------------------------------------------- # + # PASSED IN PARAMS FROM CONFIG FILE + # ----------------------------------------------------------------------- # + + # -- PRETRAIN + args_pretrain = args_eval.get('pretrain') + checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder') + model_name = args_pretrain.get('model_name', None) + patch_size = args_pretrain.get('patch_size', None) + pretrain_folder = args_pretrain.get('folder', None) + ckp_fname = args_pretrain.get('checkpoint', None) + tag = args_pretrain.get('write_tag', None) + use_sdpa = args_pretrain.get('use_sdpa', True) + use_SiLU = args_pretrain.get('use_silu', False) + tight_SiLU = args_pretrain.get('tight_silu', True) + uniform_power = args_pretrain.get('uniform_power', False) + pretrained_path = os.path.join(pretrain_folder, ckp_fname) + # Optional [for Video model]: + tubelet_size = args_pretrain.get('tubelet_size', 2) + pretrain_frames_per_clip = args_pretrain.get('frames_per_clip', 1) + + # -- DATA + args_data = args_eval.get('data') + train_data_path = [args_data.get('dataset_train')] + val_data_path = [args_data.get('dataset_val')] + dataset_type = args_data.get('dataset_type', 
'VideoDataset') + num_classes = args_data.get('num_classes') + eval_num_segments = args_data.get('num_segments', 1) + eval_frames_per_clip = args_data.get('frames_per_clip', 16) + eval_frame_step = args_pretrain.get('frame_step', 4) + eval_duration = args_pretrain.get('clip_duration', None) + eval_num_views_per_segment = args_data.get('num_views_per_segment', 1) + + # -- OPTIMIZATION + args_opt = args_eval.get('optimization') + resolution = args_opt.get('resolution', 224) + batch_size = args_opt.get('batch_size') + attend_across_segments = args_opt.get('attend_across_segments', False) + num_epochs = args_opt.get('num_epochs') + wd = args_opt.get('weight_decay') + start_lr = args_opt.get('start_lr') + lr = args_opt.get('lr') + final_lr = args_opt.get('final_lr') + warmup = args_opt.get('warmup') + use_bfloat16 = args_opt.get('use_bfloat16') + + # -- EXPERIMENT-ID/TAG (optional) + resume_checkpoint = args_eval.get('resume_checkpoint', False) or resume_preempt + eval_tag = args_eval.get('tag', None) + + # ----------------------------------------------------------------------- # + + try: + mp.set_start_method('spawn') + except Exception: + pass + + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + world_size, rank = init_distributed() + logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + + # -- log/checkpointing paths + folder = os.path.join(pretrain_folder, 'video_classification_frozen/') + if eval_tag is not None: + folder = os.path.join(folder, eval_tag) + if not os.path.exists(folder): + os.makedirs(folder, exist_ok=True) + log_file = os.path.join(folder, f'{tag}_r{rank}.csv') + latest_path = os.path.join(folder, f'{tag}-latest.pth.tar') + + # -- make csv_logger + if rank == 0: + csv_logger = CSVLogger(log_file, + ('%d', 'epoch'), + ('%.5f', 'loss'), + ('%.5f', 'acc')) + + # Initialize model + + # -- pretrained encoder (frozen) + encoder = init_model( + crop_size=resolution, + device=device, + pretrained=pretrained_path, + model_name=model_name, + patch_size=patch_size, + tubelet_size=tubelet_size, + frames_per_clip=pretrain_frames_per_clip, + uniform_power=uniform_power, + checkpoint_key=checkpoint_key, + use_SiLU=use_SiLU, + tight_SiLU=tight_SiLU, + use_sdpa=use_sdpa) + if pretrain_frames_per_clip == 1: + # Process each frame independently and aggregate + encoder = FrameAggregation(encoder).to(device) + else: + # Process each video clip independently and aggregate + encoder = ClipAggregation( + encoder, + tubelet_size=tubelet_size, + attend_across_segments=attend_across_segments + ).to(device) + encoder.eval() + for p in encoder.parameters(): + p.requires_grad = False + + # -- init classifier + classifier = PegAttentiveClassifier( + embed_dim=encoder.embed_dim, + num_classes=num_classes, + ).to(device) + + train_loader = make_dataloader( + dataset_type=dataset_type, + root_path=train_data_path, + resolution=resolution, + frames_per_clip=eval_frames_per_clip, + frame_step=eval_frame_step, + eval_duration=eval_duration, + num_segments=eval_num_segments if attend_across_segments else 1, + num_views_per_segment=1, + allow_segment_overlap=True, + batch_size=batch_size, + world_size=world_size, + rank=rank, + training=True) + val_loader = make_dataloader( + dataset_type=dataset_type, + root_path=val_data_path, + resolution=resolution, + frames_per_clip=eval_frames_per_clip, + frame_step=eval_frame_step, + num_segments=eval_num_segments, + eval_duration=eval_duration, + 
num_views_per_segment=eval_num_views_per_segment, + allow_segment_overlap=True, + batch_size=batch_size, + world_size=world_size, + rank=rank, + training=False) + ipe = len(train_loader) + logger.info(f'Dataloader created... iterations per epoch: {ipe}') + + # -- optimizer and scheduler + optimizer, scaler, scheduler, wd_scheduler = init_opt( + classifier=classifier, + wd=wd, + start_lr=start_lr, + ref_lr=lr, + final_lr=final_lr, + iterations_per_epoch=ipe, + warmup=warmup, + num_epochs=num_epochs, + use_bfloat16=use_bfloat16) + classifier = DistributedDataParallel(classifier, static_graph=True) + + # -- load training checkpoint + start_epoch = 0 + if resume_checkpoint: + classifier, optimizer, scaler, start_epoch = load_checkpoint( + device=device, + r_path=latest_path, + classifier=classifier, + opt=optimizer, + scaler=scaler) + for _ in range(start_epoch*ipe): + scheduler.step() + wd_scheduler.step() + + def save_checkpoint(epoch): + save_dict = { + 'classifier': classifier.state_dict(), + 'opt': optimizer.state_dict(), + 'scaler': None if scaler is None else scaler.state_dict(), + 'epoch': epoch, + 'batch_size': batch_size, + 'world_size': world_size, + 'lr': lr + } + if rank == 0: + torch.save(save_dict, latest_path) + + # TRAIN LOOP + for epoch in range(start_epoch, num_epochs): + logger.info('Epoch %d' % (epoch + 1)) + train_acc = run_one_epoch( + device=device, + training=True, + num_temporal_views=eval_num_segments if attend_across_segments else 1, + attend_across_segments=attend_across_segments, + num_spatial_views=1, + encoder=encoder, + classifier=classifier, + scaler=scaler, + optimizer=optimizer, + scheduler=scheduler, + wd_scheduler=wd_scheduler, + data_loader=train_loader, + use_bfloat16=use_bfloat16) + + val_acc = run_one_epoch( + device=device, + training=False, + num_temporal_views=eval_num_segments, + attend_across_segments=attend_across_segments, + num_spatial_views=eval_num_views_per_segment, + encoder=encoder, + classifier=classifier, + scaler=scaler, + optimizer=optimizer, + scheduler=scheduler, + wd_scheduler=wd_scheduler, + data_loader=val_loader, + use_bfloat16=use_bfloat16) + + logger.info('[%5d] train: %.3f test: %.3f' % (epoch + 1, train_acc, val_acc)) + # logger.info('[%5d] train: %.3f%% test: %.3f%%' % (epoch + 1, train_acc, val_acc)) + if rank == 0: + csv_logger.log(epoch + 1, train_acc, val_acc) + save_checkpoint(epoch + 1) + + +def run_one_epoch( + device, + training, + encoder, + classifier, + scaler, + optimizer, + scheduler, + wd_scheduler, + data_loader, + use_bfloat16, + num_spatial_views, + num_temporal_views, + attend_across_segments, +): + + classifier.train(mode=training) + criterion = torch.nn.CrossEntropyLoss() + top1_meter = AverageMeter() + for itr, data in enumerate(data_loader): + + if training: + scheduler.step() + wd_scheduler.step() + + with torch.cuda.amp.autocast(dtype=torch.float16, enabled=use_bfloat16): + + # Load data and put on GPU + clips = [ + [dij.to(device, non_blocking=True) for dij in di] # iterate over spatial views of clip + for di in data[0] # iterate over temporal index of clip + ] + clip_indices = [d.to(device, non_blocking=True) for d in data[2]] + labels = data[1].to(device) + batch_size = len(labels) + + print("is it training?",training) + + # Forward and prediction + with torch.no_grad(): + outputs = encoder(clips, clip_indices) + # print("outputs with encoder applied to clips and clips indices shape:", outputs[0].shape) + if not training: + if attend_across_segments: + outputs = [classifier(o) for o in outputs] + 
#print("outputs shape", outputs[0].shape) + else: + outputs = [[classifier(ost) for ost in os] for os in outputs] + #print("ouputs shape", outputs[0].shape) + if training: + if attend_across_segments: + outputs = [classifier(o) for o in outputs] + # print("outputs attend:", outputs[0].shape) + else: + outputs = [[classifier(ost) for ost in os] for os in outputs] + # print("outputs NOT attend:", outputs[0].shape) + # print("outputs is:", outputs) + print("outputs[0] final shape", outputs[0].shape) + print("labels[0] final shape:", labels[0].shape) + + # save output and label as images (comment this out when done testing) + plot_guess_img(outputs[0][0,:], output_filename = 'outputs-0.png') + plot_guess_img(labels[0], output_filename = 'labels-0.png') + + + #print("PRINT outputs[0][0:]", outputs[0][0:]) + #print("PRINT 'labels[0]",labels[0]) + #print("PRINT length labels", len(labels)) + + # Compute loss + if attend_across_segments: + loss = sum([criterion(o, labels) for o in outputs]) / len(outputs) + else: + loss = sum([sum([criterion(ost, labels) for ost in os]) for os in outputs]) / len(outputs) / len(outputs[0]) + with torch.no_grad(): + if attend_across_segments: + outputs = sum([F.softmax(o, dim=1) for o in outputs]) / len(outputs) + else: + outputs = sum([sum([F.softmax(ost, dim=1) for ost in os]) for os in outputs]) / len(outputs) / len(outputs[0]) + top1_acc = 100. * outputs.max(dim=1).indices.eq(labels).sum() / batch_size + top1_acc = float(AllReduce.apply(top1_acc)) + top1_meter.update(top1_acc) + + if training: + if use_bfloat16: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0) + scaler.step(optimizer) + scaler.update() + for name, param in model.named_parameters(): + if param.grad is not None: + print(f"Parameter: {name}, Gradient norm: {param.grad.norm().item()}") + else: + loss.backward() + torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0) + optimizer.step() + for name, param in model.named_parameters(): + if param.grad is not None: + print(f"Parameter: {name}, Gradient norm: {param.grad.norm().item()}") + optimizer.zero_grad() + + if itr % 20 == 0: + logger.info('[%5d] %.3f%% (loss: %.3f) [mem: %.2e]' + % (itr, top1_meter.avg, loss, + torch.cuda.max_memory_allocated() / 1024.**2)) + + return top1_meter.avg + + + +def load_checkpoint( + device, + r_path, + classifier, + opt, + scaler +): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + epoch = checkpoint['epoch'] + + # -- loading encoder + pretrained_dict = checkpoint['classifier'] + msg = classifier.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained classifier from epoch {epoch} with msg: {msg}') + + # -- loading optimizer + opt.load_state_dict(checkpoint['opt']) + if scaler is not None: + scaler.load_state_dict(checkpoint['scaler']) + logger.info(f'loaded optimizers from epoch {epoch}') + logger.info(f'read-path: {r_path}') + del checkpoint + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return classifier, opt, scaler, epoch + + +def load_pretrained( + encoder, + pretrained, + checkpoint_key='target_encoder' +): + logger.info(f'Loading pretrained model from {pretrained}') + checkpoint = torch.load(pretrained, map_location='cpu') + try: + pretrained_dict = checkpoint[checkpoint_key] + except Exception: + pretrained_dict = checkpoint['encoder'] + + pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()} + 
pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()} + for k, v in encoder.state_dict().items(): + if k not in pretrained_dict: + logger.info(f'key "{k}" could not be found in loaded state dict') + elif pretrained_dict[k].shape != v.shape: + logger.info(f'key "{k}" is of different shape in model and loaded state dict') + pretrained_dict[k] = v + msg = encoder.load_state_dict(pretrained_dict, strict=False) + print(encoder) + logger.info(f'loaded pretrained model with msg: {msg}') + logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}') + del checkpoint + return encoder + + +def make_dataloader( + root_path, + batch_size, + world_size, + rank, + dataset_type='VideoDataset', + resolution=224, + frames_per_clip=16, + frame_step=4, + num_segments=8, + eval_duration=None, + num_views_per_segment=1, + allow_segment_overlap=True, + training=False, + num_workers=8, + subset_file=None +): + # Make Video Transforms + transform = make_transforms( + training=training, + num_views_per_clip=num_views_per_segment, + random_horizontal_flip=False, + random_resize_aspect_ratio=(0.75, 4/3), + random_resize_scale=(0.08, 1.0), + reprob=0.25, + auto_augment=True, + motion_shift=False, + crop_size=resolution, + ) + + data_loader, _ = init_data( + data=dataset_type, + root_path=root_path, + transform=transform, + batch_size=batch_size, + world_size=world_size, + rank=rank, + clip_len=frames_per_clip, + frame_sample_rate=frame_step, + duration=eval_duration, + num_clips=num_segments, + allow_clip_overlap=allow_segment_overlap, + num_workers=num_workers, + copy_data=False, + drop_last=False, + subset_file=subset_file) + return data_loader + + +def init_model( + device, + pretrained, + model_name, + patch_size=16, + crop_size=224, + # Video specific parameters + frames_per_clip=16, + tubelet_size=2, + use_sdpa=False, + use_SiLU=False, + tight_SiLU=True, + uniform_power=False, + checkpoint_key='target_encoder' +): + encoder = vit.__dict__[model_name]( + img_size=crop_size, + patch_size=patch_size, + num_frames=frames_per_clip, + tubelet_size=tubelet_size, + uniform_power=uniform_power, + use_sdpa=use_sdpa, + use_SiLU=use_SiLU, + tight_SiLU=tight_SiLU, + ) + + encoder.to(device) + encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key) + return encoder + + +def init_opt( + classifier, + iterations_per_epoch, + start_lr, + ref_lr, + warmup, + num_epochs, + wd=1e-6, + final_wd=1e-6, + final_lr=0.0, + use_bfloat16=False +): + param_groups = [ + { + 'params': (p for n, p in classifier.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in classifier.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': True, + 'weight_decay': 0 + } + ] + + logger.info('Using AdamW') + optimizer = torch.optim.AdamW(param_groups) + scheduler = WarmupCosineSchedule( + optimizer, + warmup_steps=int(warmup*iterations_per_epoch), + start_lr=start_lr, + ref_lr=ref_lr, + final_lr=final_lr, + T_max=int(num_epochs*iterations_per_epoch)) + wd_scheduler = CosineWDSchedule( + optimizer, + ref_wd=wd, + final_wd=final_wd, + T_max=int(num_epochs*iterations_per_epoch)) + scaler = torch.cuda.amp.GradScaler() if use_bfloat16 else None + return optimizer, scaler, scheduler, wd_scheduler diff --git a/evals/video_classification_frozen/reference_with_border.png b/evals/video_classification_frozen/reference_with_border.png new file mode 100644 index 00000000..9a2b6968 Binary 
files /dev/null and b/evals/video_classification_frozen/reference_with_border.png differ diff --git a/p_mini_test.csv b/p_mini_test.csv new file mode 100644 index 00000000..71e34e23 --- /dev/null +++ b/p_mini_test.csv @@ -0,0 +1,3 @@ +/scratch/ki2130/my-jepa/jepa/pendulum/mini_test/1_20-04_957-541.mp4 /scratch/ki2130/my-jepa/jepa/video_answers/video-answers-1d/1_20-04_957-541.npy +/scratch/ki2130/my-jepa/jepa/pendulum/mini_test/1_21-04_957-541.mp4 /scratch/ki2130/my-jepa/jepa/video_answers/video-answers-1d/1_21-04_957-541.npy + diff --git a/p_mini_train.csv b/p_mini_train.csv new file mode 100644 index 00000000..18d4699a --- /dev/null +++ b/p_mini_train.csv @@ -0,0 +1,5 @@ +/scratch/ki2130/my-jepa/jepa/pendulum/mini_train/1_09-04_961-541.mp4 /scratch/ki2130/my-jepa/jepa/video_answers/video-answers-1d/1_09-04_961-541.npy +/scratch/ki2130/my-jepa/jepa/pendulum/mini_train/1_09-05_960-541.mp4 /scratch/ki2130/my-jepa/jepa/video_answers/video-answers-1d/1_09-05_960-541.npy +/scratch/ki2130/my-jepa/jepa/pendulum/mini_train/1_09-06_960-541.mp4 /scratch/ki2130/my-jepa/jepa/video_answers/video-answers-1d/1_09-06_960-541.npy +/scratch/ki2130/my-jepa/jepa/pendulum/mini_train/1_13-07_960-541.mp4 /scratch/ki2130/my-jepa/jepa/video_answers/video-answers-1d/1_13-07_960-541.npy +/scratch/ki2130/my-jepa/jepa/pendulum/mini_train/1_17-10_960-541.mp4 /scratch/ki2130/my-jepa/jepa/video_answers/video-answers-1d/1_17-10_960-541.npy diff --git a/src/datasets/video_dataset.py b/src/datasets/video_dataset.py index b05cc701..ffafb4ff 100644 --- a/src/datasets/video_dataset.py +++ b/src/datasets/video_dataset.py @@ -142,6 +142,12 @@ def __init__( num_samples = len(data) self.num_samples_per_dataset.append(len(data)) + elif data_path[-11:] == '_directory/': + samples = [file for file in os.listdir(data_path) if file.endswith('.mp4')] + labels = [np.load(file) for file in os.listdir(data_path) if file.endswith('.npy')] + num_samples = len(samples) + self.num_samples_per_dataset.append(num_samples) + # [Optional] Weights for each sample to be used by downstream # weighted video sampler self.sample_weights = None @@ -152,6 +158,22 @@ def __init__( self.samples = samples self.labels = labels + + # kat + def load_label_from_file(self, label_path): + # Load NumPy array from file + label_array = np.load(label_path) + + # Convert NumPy array to PyTorch tensor + label_tensor = torch.from_numpy(label_array) + + # Remove the extra dimension using torch.squeeze() + label_tensor = label_tensor.squeeze(dim=-1) + + # Convert type + label_tensor = label_tensor.to(dtype=torch.float32) + + return label_tensor def __getitem__(self, index): sample = self.samples[index] @@ -167,6 +189,8 @@ def __getitem__(self, index): # Label/annotations for video label = self.labels[index] + # kat + label_tensor = self.load_label_from_file(label) def split_into_clips(video): """ Split video into a list of clips """ @@ -181,7 +205,8 @@ def split_into_clips(video): if self.transform is not None: buffer = [self.transform(clip) for clip in buffer] - return buffer, label, clip_indices + # kat + return buffer, label_tensor, clip_indices def loadvideo_decord(self, sample): """ Load video content using Decord """ diff --git a/src/models/pegs_attentive_probe.py b/src/models/pegs_attentive_probe.py new file mode 100644 index 00000000..ac6e5204 --- /dev/null +++ b/src/models/pegs_attentive_probe.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn + + +class PegAttentiveClassifier(nn.Module): + """ Attentive Classifier """ + def __init__( + self, + embed_dim=768, # note that the embed dim gets set from the encoder parameters (vit) + num_classes=165 + ): + super().__init__() + # self.linear = nn.Linear(12544*embed_dim, num_classes, bias=False) + # self.linear = nn.Linear(embed_dim, num_classes, bias=True) + self.linear = nn.Linear(330, num_classes, bias=True) # 1024 becomes 98 after the avgpool1d + # self.softmax = nn.Softmax() + # self.avgpool = nn.AvgPool1d(50, stride=10) + self.adaptivepool = nn.AdaptiveAvgPool2d((1, 330)) + + def forward(self, x): + print("input to classifier shape:", x.shape) + # x = torch.sum(x, dim=1) + # print("summed x:", x.shape) + # print("min after sum", torch.min(x)) + # print("max after sum", torch.max(x)) + ## x = self.softmax(x) + ## print("min after softmax", torch.min(x)) + ## print("max after softmax", torch.max(x)) + # x = self.avgpool(x) + # print("min after avgpool1d", torch.min(x)) + # print("max after avgpool1d", torch.max(x)) + x = self.adaptivepool(x) + # print("min after adaptivepool", torch.min(x)) + # print("max after adaptivepool", torch.max(x)) + x = self.linear(x) + # print("min after linear", torch.min(x)) + # print("max after linear", torch.max(x)) + return x + + # flattened_x = x.flatten(1,-1) + # print("flattened x:", flattened_x.shape) + # x = self.linear(flattened_x) diff --git a/src/models/vision_transformer.py b/src/models/vision_transformer.py index a8748dfd..c85b150e 100644 --- a/src/models/vision_transformer.py +++ b/src/models/vision_transformer.py @@ -173,7 +173,7 @@ def forward(self, x, masks=None): if pos_embed is not None: x += pos_embed B, N, D = x.shape - + # print("pos_embed x shape:", x.shape) # Mask away unwanted tokens (if masks provided) if masks is not None: x = apply_masks(x, masks) @@ -191,7 +191,9 @@ def forward(self, x, masks=None): if self.norm is not None: x = self.norm(x) - + print("shape of encoder:", x.shape) + # print("min of encoded", torch.min(x)) + # print("max of encoded", torch.max(x)) return x def interpolate_pos_encoding(self, x, pos_embed):
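Taken together, the new files replace the K400 top-1 probe with a 165-way peg probe: frozen encoder features are pooled by `PegAttentiveClassifier` down to a 330-wide vector and mapped to 165 logits, the probe is trained with `BCEWithLogitsLoss` in `pegs_eval.py`, and sigmoid outputs are visualized with `plot_guess_img`. A minimal sketch of that probe/loss pairing, using a random tensor in place of the frozen encoder features (the batch size, token count, and 1024-dim embedding below are illustrative assumptions, not shapes verified against the encoder output):

```python
import torch
import torch.nn as nn

# Stand-in for aggregated encoder features: (batch, tokens, embed_dim).
feats = torch.randn(4, 1568, 1024)

# Mirrors the shape flow of PegAttentiveClassifier: pool the (tokens, embed_dim)
# grid down to a fixed 1x330 vector per sample, then map to 165 peg logits.
pool = nn.AdaptiveAvgPool2d((1, 330))
linear = nn.Linear(330, 165)

logits = linear(pool(feats)).squeeze(1)          # (4, 165)
targets = torch.randint(0, 2, (4, 165)).float()  # multi-hot peg labels

loss = nn.BCEWithLogitsLoss()(logits, targets)
probs = torch.sigmoid(logits)                    # per-sample vector of the kind plot_guess_img expects
print(loss.item(), probs.shape)
```

Pooling away the token dimension keeps the probe independent of the encoder's sequence length, at the cost of discarding per-token detail relative to the upstream attentive classifier.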