diff --git a/MethaneModel/Gas-ViT-8-classes.ipynb b/MethaneModel/Gas-ViT-8-classes.ipynb new file mode 100644 index 0000000..29b1353 --- /dev/null +++ b/MethaneModel/Gas-ViT-8-classes.ipynb @@ -0,0 +1,888 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "F_-w0DYiVX7g" + }, + "source": [ + "### This is a Pytorch version of the work, for easier time working with ViT" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qSgTfcIPVX7l" + }, + "source": [ + "# Step 1: Load and Preprocess the Dataset\n", + "\n", + "### Load the GasVid dataset\n", + "### Preprocess the data\n", + "### Split the dataset into training and test sets" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "%pip install --upgrade pip\n", + "%pip install -r ../requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# if apple and want MPS acceleration do this\n", + "# %%capture\n", + "# %pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "executionInfo": { + "elapsed": 10680, + "status": "ok", + "timestamp": 1696464635710, + "user": { + "displayName": "Angeline Lee", + "userId": "12532490570362671000" + }, + "user_tz": 420 + }, + "id": "JObjEqgBVX7n" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-03-19 16:28:23.512869: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-03-19 16:28:23.512900: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-03-19 16:28:23.512920: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" + ] + } + ], + "source": [ + "# Imports\n", + "\n", + "import os\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import cv2\n", + "import torch\n", + "import torch.nn as nn\n", + "from torchvision import transforms\n", + "from transformers import ViTForImageClassification, ViTImageProcessor, ViTConfig\n", + "from torch.utils.data import DataLoader, Dataset, Subset\n", + "from tqdm.notebook import tqdm\n", + "import pandas as pd\n", + "\n", + "from sklearn.metrics import RocCurveDisplay, roc_curve, ConfusionMatrixDisplay, confusion_matrix\n", + "\n", + "from sklearn.mixture import GaussianMixture\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3-rMPE8CVX7p" + }, + "source": [ + "### Setting up Directories" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "xA87iKePVX7p" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "C0 Training Data Count: 42902\n", + "C0 Testing Data Count: 31002\n", + "C1 Training Data Count: 42913\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "C1 Testing Data Count: 31010\n", + "C2 Training Data Count: 42874\n", + "C2 Testing Data Count: 30996\n", + "C3 Training Data Count: 42897\n", + "C3 Testing Data Count: 31002\n", + "C4 Training Data Count: 42898\n", + "C4 Testing Data 
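The counting loop in this cell tallies every directory entry under each `C0`–`C7` folder, while the dataset class defined later filters to image extensions, so the two counts can drift apart if stray files are present. A minimal alternative sketch (assuming the same `frame_train_data_dir` / `frame_test_data_dir` layout defined in this cell) that counts only image frames:

```python
# Sketch: count only image frames per class with pathlib (assumes the
# background_sub_movingavg8_frames/{train,test}/C0..C7 layout used above).
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

def count_frames(split_dir: str) -> dict:
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for p in class_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS
            )
    return counts

print(count_frames(frame_train_data_dir))
print(count_frames(frame_test_data_dir))
```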
Count: 31013\n", + "C5 Training Data Count: 42889\n", + "C5 Testing Data Count: 30991\n", + "C6 Training Data Count: 42907\n", + "C6 Testing Data Count: 31004\n", + "C7 Training Data Count: 42910\n", + "C7 Testing Data Count: 30953\n" + ] + } + ], + "source": [ + "# get generic path to directory\n", + "dir_path = os.path.dirname(os.path.realpath(\"__file__\"))\n", + "\n", + "# get all raw video data directories\n", + "data_dir = os.path.join(dir_path, 'data')\n", + "\n", + "train_data_dir = os.path.join(data_dir, 'train')\n", + "test_data_dir = os.path.join(data_dir, 'test')\n", + "\n", + "frame_data_dir = os.path.join(dir_path, 'background_sub_movingavg8_frames')\n", + "frame_train_data_dir = os.path.join(frame_data_dir, 'train')\n", + "frame_test_data_dir = os.path.join(frame_data_dir, 'test')\n", + "\n", + "for i in range(8):\n", + " train_count = 0\n", + " for file in os.listdir(os.path.join(frame_train_data_dir, \"C%d\"%i)):\n", + " train_count += 1\n", + " print (\"C%d Training Data Count: \"%i, train_count, flush=True)\n", + " test_count = 0\n", + " for file in os.listdir(os.path.join(frame_test_data_dir, \"C%d\"%i)):\n", + " test_count += 1\n", + " print (\"C%d Testing Data Count: \"%i, test_count, flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Step 2: Create Dataset for Ingesting Image Frames" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class MultiClassVideoFrameDataset(Dataset):\n", + " def __init__(self, root_dir, transform=None, processor=None):\n", + " self.root_dir = root_dir\n", + " self.transform = transform\n", + " self.processor = processor\n", + " self.classes = os.listdir(root_dir) # Get class names from subdirectories\n", + "\n", + " self.frames = []\n", + " self.labels = []\n", + "\n", + " for class_idx, class_name in enumerate(self.classes):\n", + " class_dir = os.path.join(self.root_dir, class_name)\n", + " frame_list = [os.path.join(class_dir, file) for file in os.listdir(class_dir) if file.endswith(('.jpg', '.png', '.jpeg'))]\n", + " self.frames.extend(frame_list)\n", + " self.labels.extend([class_idx] * len(frame_list))\n", + "\n", + " def __len__(self):\n", + " return len(self.frames)\n", + "\n", + " def __getitem__(self, idx):\n", + " frame_path = self.frames[idx]\n", + " image = cv2.imread(frame_path)\n", + " # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert BGR to RGB\n", + "\n", + " if self.transform:\n", + " image = self.transform(image)\n", + " \n", + " if self.processor:\n", + " image = self.processor.preprocess(image, return_tensors=\"pt\")\n", + "\n", + " label = self.labels[idx]\n", + "\n", + " return image, label" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "image_processor = ViTImageProcessor(\n", + " \"google/vit-base-patch16-224\",\n", + " do_normalize=True,\n", + " max_size=384,\n", + " pad_to_max_size=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# define some transforms\n", + "transform = transforms.Compose([\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "full_train_dataset = MultiClassVideoFrameDataset(root_dir=frame_train_data_dir, transform=transform, processor=image_processor)\n", + "test_dataset = MultiClassVideoFrameDataset(root_dir=frame_test_data_dir, transform=transform, 
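Note that the first positional argument of `ViTImageProcessor(...)` is `do_resize`, not a checkpoint name, so the string above does not load the `google/vit-base-patch16-224` preprocessing config (and `max_size` / `pad_to_max_size` are not ViT processor options); `ViTImageProcessor.from_pretrained(...)` is the usual way to pick up the checkpoint's image mean/std and 224x224 size. A small sketch of what `preprocess` returns, which is why the training loop below calls `.pixel_values.squeeze(1)` on each batch (the frame path here is hypothetical):

```python
# Sketch: inspect the processor output to see why the training loop below
# needs .pixel_values.squeeze(1).
import cv2
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
frame = cv2.imread("example_frame.png")       # hypothetical file; BGR uint8 array, HxWx3
encoded = processor.preprocess(frame, return_tensors="pt")
print(encoded.pixel_values.shape)             # torch.Size([1, 3, 224, 224])
# DataLoader's default collate stacks these per-sample tensors into
# (batch, 1, 3, 224, 224), hence the squeeze(1) before the forward pass.
```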
processor=image_processor)\n", + "\n", + "# Define the percentage of data to use for validation\n", + "validation_split = 0.2 # Adjust this as needed\n", + "\n", + "# Calculate the number of samples for the validation set\n", + "num_samples = len(full_train_dataset)\n", + "num_val_samples = int(validation_split * num_samples)\n", + "num_train_samples = num_samples - num_val_samples\n", + "\n", + "# Create a list of indices for the full dataset\n", + "indices = list(range(num_samples))\n", + "\n", + "# Use random sampling to split the indices into train and validation indices\n", + "val_indices = torch.randperm(num_samples)[:num_val_samples]\n", + "train_indices = list(set(indices) - set(val_indices))\n", + "\n", + "# Create Subset objects for train and validation\n", + "train_dataset = Subset(full_train_dataset, train_indices)\n", + "val_dataset = Subset(full_train_dataset, val_indices)\n", + "\n", + "train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)\n", + "test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "591161" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(full_train_dataset) + len(test_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6RAi_buLVX7r" + }, + "source": [ + "# Step 3: Build the GasViT Architecture\n", + "\n", + "### Define the GasNet architecture (GasNet-2 as mentioned in the paper)\n", + "### Implement the model using TensorFlow/Keras" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda')" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "device\n", + "# model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "configs = ViTConfig(\n", + " hidden_dropout_prob=0.5,\n", + " attention_probs_dropout_prob=0.2,\n", + " num_labels=8\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:\n", + "- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([8]) in the model instantiated\n", + "- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([8, 768]) in the model instantiated\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { + "data": { + "text/plain": [ + "ViTForImageClassification(\n", + " (vit): ViTModel(\n", + " (embeddings): ViTEmbeddings(\n", + " (patch_embeddings): ViTPatchEmbeddings(\n", + " (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n", + " )\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (encoder): ViTEncoder(\n", + " (layer): ModuleList(\n", + " (0-11): 12 x ViTLayer(\n", + " (attention): ViTAttention(\n", + " (attention): ViTSelfAttention(\n", 
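One caveat about the train/validation split cell above: `val_indices` is a tensor, and `set(indices) - set(val_indices)` compares Python ints against 0-dim tensors (which hash by identity rather than value), so in practice nothing is removed and the validation frames remain in the training subset. A safer sketch with the same 80/20 ratio, using `torch.utils.data.random_split`, which guarantees disjoint subsets:

```python
# Sketch: disjoint train/validation subsets via random_split (same 80/20 split).
from torch.utils.data import random_split

num_val_samples = int(0.2 * len(full_train_dataset))
num_train_samples = len(full_train_dataset) - num_val_samples
train_dataset, val_dataset = random_split(
    full_train_dataset,
    [num_train_samples, num_val_samples],
    generator=torch.Generator().manual_seed(0),  # fixed seed for a reproducible split
)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
```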
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.2, inplace=False)\n", + " )\n", + " (output): ViTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): ViTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " (intermediate_act_fn): GELUActivation()\n", + " )\n", + " (output): ViTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " )\n", + " )\n", + " )\n", + " (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " )\n", + " (classifier): Linear(in_features=768, out_features=8, bias=True)\n", + ")" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Initialize the ViT feature extractor and model\n", + "model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', config=configs, ignore_mismatched_sizes=True)\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Linear(in_features=768, out_features=8, bias=True)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Change classifier layer to have num classes consistent with dataset\n", + "model.classifier.out_features = len(full_train_dataset.classes)\n", + "model.classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['C6', 'C7', 'C2', 'C1', 'C5', 'C4', 'C3', 'C0']" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataloader.dataset.dataset.classes" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "def train(model, weight=None, num_epochs=10):\n", + " criterion = nn.CrossEntropyLoss(weight=weight) # extendable for multiclass classification as well\n", + " optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n", + "\n", + " # can try out lr scheduler later if needed\n", + " # can also try out warmup ratio\n", + "\n", + " for epoch in range(num_epochs):\n", + " model.train()\n", + " for batch_images, batch_labels in tqdm(train_dataloader):\n", + " batch_image_pixels, batch_labels = batch_images.pixel_values.squeeze(1).to(device), batch_labels.to(device)\n", + " optimizer.zero_grad()\n", + " outputs = model(batch_image_pixels).logits\n", + " loss = criterion(outputs, batch_labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " print(f\"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}\", flush=True)\n", + "\n", + " model.eval()\n", + " accuracy = 0\n", + " total_samples = 0\n", + "\n", + " with torch.no_grad():\n", + " for batch_images, batch_labels in tqdm(val_dataloader, leave=False):\n", + " batch_image_pixels, batch_labels = batch_images.pixel_values.squeeze(1).to(device), batch_labels.to(device)\n", + " outputs = 
model(batch_image_pixels).logits\n", + " _, predicted = torch.max(outputs, 1)\n", + " accuracy += (predicted == batch_labels).sum().item()\n", + " total_samples += batch_labels.size(0)\n", + "\n", + " validation_accuracy = accuracy / total_samples\n", + " print(f\"Validation Accuracy: {validation_accuracy:.4f}\", flush=True)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Adjust classweights to account for class imbalance" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# Adjust Class weights here\n", + "# class_weight = torch.tensor([1]*8).float().to(device)\n", + "class_weight = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).float().to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "be86647b3fd542fc978726af3e857a5c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/10725 [00:00 1\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m6\u001b[39;49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[16], line 10\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, weight, num_epochs)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n\u001b[1;32m 9\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_images, batch_labels \u001b[38;5;129;01min\u001b[39;00m tqdm(train_dataloader):\n\u001b[1;32m 11\u001b[0m batch_image_pixels, batch_labels \u001b[38;5;241m=\u001b[39m batch_images\u001b[38;5;241m.\u001b[39mpixel_values\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device), batch_labels\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 12\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/tqdm/notebook.py:249\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 248\u001b[0m it \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msuper\u001b[39m(tqdm_notebook, \u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__iter__\u001b[39m()\n\u001b[0;32m--> 249\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m it:\n\u001b[1;32m 250\u001b[0m \u001b[38;5;66;03m# return super(tqdm...) will not catch exception\u001b[39;00m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 252\u001b[0m \u001b[38;5;66;03m# NB: except ... [ as ...] 
image = (image - image_mean) / image_std.\u001b[39;00m\n\u001b[1;32m 594\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 614\u001b[0m \u001b[38;5;124;03m `np.ndarray`: The normalized image.\u001b[39;00m\n\u001b[1;32m 615\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 616\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mnormalize\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 617\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmean\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstd\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_format\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_format\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_data_format\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_data_format\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 618\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/image_transforms.py:398\u001b[0m, in \u001b[0;36mnormalize\u001b[0;34m(image, mean, std, data_format, input_data_format)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 396\u001b[0m image \u001b[38;5;241m=\u001b[39m ((image\u001b[38;5;241m.\u001b[39mT \u001b[38;5;241m-\u001b[39m mean) \u001b[38;5;241m/\u001b[39m std)\u001b[38;5;241m.\u001b[39mT\n\u001b[0;32m--> 398\u001b[0m image \u001b[38;5;241m=\u001b[39m to_channel_dimension_format(image, data_format, input_data_format) \u001b[38;5;28;01mif\u001b[39;00m data_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m image\n\u001b[1;32m 399\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m image\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "train(model, class_weight, 6)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# Specify the file path for saving the model\n", + "model_path = 'vit_model_8.pth'\n", + "\n", + "# Save the model's state_dict to the specified file\n", + "torch.save(model.state_dict(), model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_path = 'vit_model_8.pth'\n", + "model.load_state_dict(torch.load(model_path))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jddooPmpVX7t" + }, + "source": [ + "# Step 4: Evaluate the model on the Test Dataset\n", + "\n", + "### Generate evaluation metrics and plots such as confusion matrix and ROC curves, F1 score, etc." 
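For the eight-class setup, note that `sklearn.metrics.f1_score` defaults to `average='binary'` and raises an error on multiclass labels, so the F1 cell further down needs an explicit averaging mode. A sketch of the metrics used below, assuming the `predictions` and `truth_labels` arrays returned by `predict()`:

```python
# Sketch: multiclass metrics (assumes predictions / truth_labels from predict()).
from sklearn.metrics import accuracy_score, f1_score, classification_report

accuracy = accuracy_score(truth_labels, predictions)
macro_f1 = f1_score(truth_labels, predictions, average="macro")  # unweighted mean over classes
print(f"accuracy: {accuracy:.4f}  macro-F1: {macro_f1:.4f}")
print(classification_report(truth_labels, predictions,
                            target_names=test_dataset.classes))  # per-class precision/recall/F1
```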
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_v4ikO-tVX7t" + }, + "source": [ + "We are primarily concerned with high false positive rate due to the extreme class imbalance" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "def predict(model):\n", + " model.eval()\n", + " accuracy = 0\n", + " total_samples = 0\n", + " predictions = [] # List to store the predictions\n", + " truth_labels = [] # List to store the truth labels\n", + "\n", + " with torch.no_grad():\n", + " for batch_images, batch_labels in tqdm(test_dataloader, leave=False):\n", + " batch_image_pixels, batch_labels = batch_images.pixel_values.squeeze(1).to(device), batch_labels.to(device)\n", + " outputs = model(batch_image_pixels).logits\n", + " _, predicted = torch.max(outputs, 1)\n", + " accuracy += (predicted == batch_labels).sum().item()\n", + " total_samples += batch_labels.size(0)\n", + "\n", + " predictions.extend(predicted.cpu().numpy())\n", + " truth_labels.extend(batch_labels.cpu().numpy())\n", + "\n", + " validation_accuracy = accuracy / total_samples\n", + " print(f\"Test Accuracy: {validation_accuracy:.4f}\", flush=True)\n", + " return predictions, truth_labels" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7a6c4b08e8e940ee82ee95b44dbf9bf7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/7750 [00:00 807\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstat\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mroot\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mst_mtime\n\u001b[1;32m 808\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlookup\u001b[38;5;241m.\u001b[39mcache_clear()\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/bestlab/anaconda3/envs/angeline/lib/python310.zip'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[22], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m predictions, truth_labels \u001b[38;5;241m=\u001b[39m \u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m predictions, truth_labels \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(predictions), np\u001b[38;5;241m.\u001b[39marray(truth_labels)\n\u001b[1;32m 3\u001b[0m df_predictions \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(data\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpredictions\u001b[39m\u001b[38;5;124m\"\u001b[39m: predictions, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtruth_labels\u001b[39m\u001b[38;5;124m\"\u001b[39m: truth_labels})\n", + "Cell \u001b[0;32mIn[21], line 9\u001b[0m, in \u001b[0;36mpredict\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m 6\u001b[0m truth_labels \u001b[38;5;241m=\u001b[39m [] \u001b[38;5;66;03m# List to store the truth labels\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 9\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_images, batch_labels \u001b[38;5;129;01min\u001b[39;00m tqdm(test_dataloader, 
\u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Find metadata directories in paths heuristically.\"\"\"\u001b[39;00m\n\u001b[1;32m 902\u001b[0m prepared \u001b[38;5;241m=\u001b[39m Prepared(name)\n\u001b[1;32m 903\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m itertools\u001b[38;5;241m.\u001b[39mchain\u001b[38;5;241m.\u001b[39mfrom_iterable(\n\u001b[0;32m--> 904\u001b[0m \u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msearch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprepared\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m path \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(FastPath, paths)\n\u001b[1;32m 905\u001b[0m )\n", + "File \u001b[0;32m~/anaconda3/envs/angeline/lib/python3.10/importlib/metadata/__init__.py:802\u001b[0m, in \u001b[0;36mFastPath.search\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msearch\u001b[39m(\u001b[38;5;28mself\u001b[39m, name):\n\u001b[0;32m--> 802\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlookup(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmtime\u001b[49m)\u001b[38;5;241m.\u001b[39msearch(name)\n", + "File \u001b[0;32m~/anaconda3/envs/angeline/lib/python3.10/importlib/metadata/__init__.py:807\u001b[0m, in \u001b[0;36mFastPath.mtime\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 805\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmtime\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 806\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m suppress(\u001b[38;5;167;01mOSError\u001b[39;00m):\n\u001b[0;32m--> 807\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstat\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mroot\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mst_mtime\n\u001b[1;32m 808\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlookup\u001b[38;5;241m.\u001b[39mcache_clear()\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "predictions, truth_labels = predict(model)\n", + "predictions, truth_labels = np.array(predictions), np.array(truth_labels)\n", + "df_predictions = pd.DataFrame(data={\"predictions\": predictions, \"truth_labels\": truth_labels})\n", + "df_predictions.to_csv(\"vit_8_preds.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'truth_labels' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[23], line 6\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ConfusionMatrixDisplay\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m#source: https://vitalflux.com/python-draw-confusion-matrix-matplotlib/\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m conf_matrix \u001b[38;5;241m=\u001b[39m confusion_matrix(y_true\u001b[38;5;241m=\u001b[39m\u001b[43mtruth_labels\u001b[49m\u001b[38;5;241m.\u001b[39mastype(\u001b[38;5;28mint\u001b[39m), 
y_pred\u001b[38;5;241m=\u001b[39mpredictions)\n\u001b[1;32m 7\u001b[0m fig, ax \u001b[38;5;241m=\u001b[39m plt\u001b[38;5;241m.\u001b[39msubplots(figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m7.5\u001b[39m, \u001b[38;5;241m7.5\u001b[39m))\n\u001b[1;32m 8\u001b[0m ax\u001b[38;5;241m.\u001b[39mmatshow(conf_matrix, cmap\u001b[38;5;241m=\u001b[39mplt\u001b[38;5;241m.\u001b[39mcm\u001b[38;5;241m.\u001b[39mBlues, alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.3\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'truth_labels' is not defined" + ] + } + ], + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "from sklearn.metrics import ConfusionMatrixDisplay\n", + "\n", + "#source: https://vitalflux.com/python-draw-confusion-matrix-matplotlib/\n", + "\n", + "conf_matrix = confusion_matrix(y_true=truth_labels.astype(int), y_pred=predictions)\n", + "fig, ax = plt.subplots(figsize=(7.5, 7.5))\n", + "ax.matshow(conf_matrix, cmap=plt.cm.Blues, alpha=0.3)\n", + "for i in range(conf_matrix.shape[0]):\n", + " for j in range(conf_matrix.shape[1]):\n", + " ax.text(x=j, y=i,s=conf_matrix[i, j], va='center', ha='center', size='xx-large')\n", + " \n", + "plt.xlabel('Predictions', fontsize=18)\n", + "plt.ylabel('Actuals', fontsize=18)\n", + "plt.title(f'Confusion Matrix (0: Leak, 1:Nonleak)', fontsize=18)\n", + "plt.savefig(\"vit_8_confusion.png\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'truth_labels' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[24], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m accuracy_score,f1_score\n\u001b[0;32m----> 2\u001b[0m accuracy \u001b[38;5;241m=\u001b[39m accuracy_score(\u001b[43mtruth_labels\u001b[49m, predictions)\n", + "\u001b[0;31mNameError\u001b[0m: name 'truth_labels' is not defined" + ] + } + ], + "source": [ + "from sklearn.metrics import accuracy_score,f1_score\n", + "accuracy = accuracy_score(truth_labels, predictions)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'accuracy' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[25], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maccuracy: \u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[43maccuracy\u001b[49m, flush\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'accuracy' is not defined" + ] + } + ], + "source": [ + "print(\"accuracy: \", accuracy, flush=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'truth_labels' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + 
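The confusion-matrix cell above still carries the binary title "(0: Leak, 1: Nonleak)" from an earlier two-class version; with eight classes it helps to label the axes with the actual class names. A sketch using `ConfusionMatrixDisplay` (already imported earlier), assuming `truth_labels` and `predictions` are available:

```python
# Sketch: 8-class confusion matrix with readable class labels on both axes.
conf_matrix = confusion_matrix(y_true=truth_labels.astype(int), y_pred=predictions)
fig, ax = plt.subplots(figsize=(9, 9))
ConfusionMatrixDisplay(conf_matrix, display_labels=test_dataset.classes).plot(
    ax=ax, cmap=plt.cm.Blues, colorbar=False
)
ax.set_xlabel("Predictions", fontsize=14)
ax.set_ylabel("Actuals", fontsize=14)
ax.set_title("Confusion Matrix (classes C0-C7)", fontsize=14)
plt.savefig("vit_8_confusion.png")
plt.show()
```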
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[26], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m f1 \u001b[38;5;241m=\u001b[39m f1_score(\u001b[43mtruth_labels\u001b[49m, predictions)\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mf1 score: \u001b[39m\u001b[38;5;124m\"\u001b[39m, f1, flush\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'truth_labels' is not defined" + ] + } + ], + "source": [ + "f1 = f1_score(truth_labels, predictions)\n", + "print(\"f1 score: \", f1, flush=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'truth_labels' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[27], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m conf_matrix \u001b[38;5;241m=\u001b[39m confusion_matrix(\u001b[43mtruth_labels\u001b[49m, predictions)\n\u001b[1;32m 2\u001b[0m per_class_accuracy \u001b[38;5;241m=\u001b[39m conf_matrix\u001b[38;5;241m.\u001b[39mdiagonal() \u001b[38;5;241m/\u001b[39m conf_matrix\u001b[38;5;241m.\u001b[39msum(axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'truth_labels' is not defined" + ] + } + ], + "source": [ + "conf_matrix = confusion_matrix(truth_labels, predictions)\n", + "per_class_accuracy = conf_matrix.diagonal() / conf_matrix.sum(axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'per_class_accuracy' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[28], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m label, acc \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m([\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mC6\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mC7\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mC2\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mC1\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mC5\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mC4\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mC3\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mC0\u001b[39m\u001b[38;5;124m'\u001b[39m], \u001b[43mper_class_accuracy\u001b[49m):\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mClass \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlabel\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m Accuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00macc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, flush\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + 
"\u001b[0;31mNameError\u001b[0m: name 'per_class_accuracy' is not defined" + ] + } + ], + "source": [ + "for label, acc in zip(['C6', 'C7', 'C2', 'C1', 'C5', 'C4', 'C3', 'C0'], per_class_accuracy):\n", + " print(f\"Class '{label}' Accuracy: {acc:.4f}\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OCvUK3ryVX7t" + }, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "vscode": { + "interpreter": { + "hash": "92148a12c102ce31dad7b6dc4c1f8747c19091aa07dd8a934fcd2c33582f9a61" + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/MethaneModel/Gas-ViT-8-classes.py b/MethaneModel/Gas-ViT-8-classes.py new file mode 100644 index 0000000..85ef029 --- /dev/null +++ b/MethaneModel/Gas-ViT-8-classes.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python +# coding: utf-8 + +# ### This is a Pytorch version of the work, for easier time working with ViT + +# # Step 1: Load and Preprocess the Dataset +# +# ### Load the GasVid dataset +# ### Preprocess the data +# ### Split the dataset into training and test sets + +# In[2]: + + +get_ipython().run_cell_magic('capture', '', '%pip install --upgrade pip\n%pip install -r ../requirements.txt\n') + + +# In[3]: + + +# if apple and want MPS acceleration do this +# %%capture +# %pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu + + +# In[4]: + + +# Imports + +import os +import numpy as np +import matplotlib.pyplot as plt +import cv2 +import torch +import torch.nn as nn +from torchvision import transforms +from transformers import ViTForImageClassification, ViTImageProcessor, ViTConfig +from torch.utils.data import DataLoader, Dataset, Subset +from tqdm.notebook import tqdm +import pandas as pd + +from sklearn.metrics import RocCurveDisplay, roc_curve, ConfusionMatrixDisplay, confusion_matrix + +from sklearn.mixture import GaussianMixture + + +# ### Setting up Directories + +# In[5]: + + +# get generic path to directory +dir_path = os.path.dirname(os.path.realpath("__file__")) + +# get all raw video data directories +data_dir = os.path.join(dir_path, 'data') + +train_data_dir = os.path.join(data_dir, 'train') +test_data_dir = os.path.join(data_dir, 'test') + +frame_data_dir = os.path.join(dir_path, 'background_sub_movingavg8_frames') +frame_train_data_dir = os.path.join(frame_data_dir, 'train') +frame_test_data_dir = os.path.join(frame_data_dir, 'test') + +for i in range(8): + train_count = 0 + for file in os.listdir(os.path.join(frame_train_data_dir, "C%d"%i)): + train_count += 1 + print ("C%d Training Data Count: "%i, train_count, flush=True) + test_count = 0 + for file in os.listdir(os.path.join(frame_test_data_dir, "C%d"%i)): + test_count += 1 + print ("C%d Testing Data Count: "%i, test_count, flush=True) + + +# # Step 2: Create Dataset for Ingesting Image Frames + +# In[6]: + + +class MultiClassVideoFrameDataset(Dataset): + def __init__(self, root_dir, transform=None, processor=None): + self.root_dir = root_dir + self.transform = transform + self.processor = processor + self.classes = os.listdir(root_dir) # Get class 
names from subdirectories + + self.frames = [] + self.labels = [] + + for class_idx, class_name in enumerate(self.classes): + class_dir = os.path.join(self.root_dir, class_name) + frame_list = [os.path.join(class_dir, file) for file in os.listdir(class_dir) if file.endswith(('.jpg', '.png', '.jpeg'))] + self.frames.extend(frame_list) + self.labels.extend([class_idx] * len(frame_list)) + + def __len__(self): + return len(self.frames) + + def __getitem__(self, idx): + frame_path = self.frames[idx] + image = cv2.imread(frame_path) + # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert BGR to RGB + + if self.transform: + image = self.transform(image) + + if self.processor: + image = self.processor.preprocess(image, return_tensors="pt") + + label = self.labels[idx] + + return image, label + + +# In[7]: + + +image_processor = ViTImageProcessor( + "google/vit-base-patch16-224", + do_normalize=True, + max_size=384, + pad_to_max_size=True +) + + +# In[8]: + + +# define some transforms +transform = transforms.Compose([ +]) + + +# In[9]: + + +full_train_dataset = MultiClassVideoFrameDataset(root_dir=frame_train_data_dir, transform=transform, processor=image_processor) +test_dataset = MultiClassVideoFrameDataset(root_dir=frame_test_data_dir, transform=transform, processor=image_processor) + +# Define the percentage of data to use for validation +validation_split = 0.2 # Adjust this as needed + +# Calculate the number of samples for the validation set +num_samples = len(full_train_dataset) +num_val_samples = int(validation_split * num_samples) +num_train_samples = num_samples - num_val_samples + +# Create a list of indices for the full dataset +indices = list(range(num_samples)) + +# Use random sampling to split the indices into train and validation indices +val_indices = torch.randperm(num_samples)[:num_val_samples] +train_indices = list(set(indices) - set(val_indices)) + +# Create Subset objects for train and validation +train_dataset = Subset(full_train_dataset, train_indices) +val_dataset = Subset(full_train_dataset, val_indices) + +train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True) +val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False) +test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False) + + +# In[10]: + + +len(full_train_dataset) + len(test_dataset) + + +# # Step 3: Build the GasViT Architecture +# +# ### Define the GasNet architecture (GasNet-2 as mentioned in the paper) +# ### Implement the model using TensorFlow/Keras + +# In[11]: + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device +# model.to(device) + + +# In[12]: + + +configs = ViTConfig( + hidden_dropout_prob=0.5, + attention_probs_dropout_prob=0.2, + num_labels=8 +) + + +# In[13]: + + +# Initialize the ViT feature extractor and model +model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', config=configs, ignore_mismatched_sizes=True) +model.to(device) + + +# In[14]: + + +# Change classifier layer to have num classes consistent with dataset +model.classifier.out_features = len(full_train_dataset.classes) +model.classifier + + +# In[15]: + + +train_dataloader.dataset.dataset.classes + + +# In[16]: + + +def train(model, weight=None, num_epochs=10): + criterion = nn.CrossEntropyLoss(weight=weight) # extendable for multiclass classification as well + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + + # can try out lr scheduler later if needed + # can also try out warmup ratio + + for epoch in range(num_epochs): 
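+        # Per-epoch flow (descriptive note): one full pass over train_dataloader with
+        # backprop, followed by a no-grad pass over val_dataloader to report accuracy.
+        # The loss printed after the inner loop is only the final batch's loss for the
+        # epoch, not an average over all batches.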
+ model.train() + for batch_images, batch_labels in tqdm(train_dataloader): + batch_image_pixels, batch_labels = batch_images.pixel_values.squeeze(1).to(device), batch_labels.to(device) + optimizer.zero_grad() + outputs = model(batch_image_pixels).logits + loss = criterion(outputs, batch_labels) + loss.backward() + optimizer.step() + + print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}", flush=True) + + model.eval() + accuracy = 0 + total_samples = 0 + + with torch.no_grad(): + for batch_images, batch_labels in tqdm(val_dataloader, leave=False): + batch_image_pixels, batch_labels = batch_images.pixel_values.squeeze(1).to(device), batch_labels.to(device) + outputs = model(batch_image_pixels).logits + _, predicted = torch.max(outputs, 1) + accuracy += (predicted == batch_labels).sum().item() + total_samples += batch_labels.size(0) + + validation_accuracy = accuracy / total_samples + print(f"Validation Accuracy: {validation_accuracy:.4f}", flush=True) + + +# ### Adjust classweights to account for class imbalance + +# In[17]: + + +# Adjust Class weights here +# class_weight = torch.tensor([1]*8).float().to(device) +class_weight = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).float().to(device) + + +# In[18]: + + +train(model, class_weight, 6) + + +# In[19]: + + +# Specify the file path for saving the model +model_path = 'vit_model_8.pth' + +# Save the model's state_dict to the specified file +torch.save(model.state_dict(), model_path) + + +# In[20]: + + +model_path = 'vit_model_8.pth' +model.load_state_dict(torch.load(model_path)) + + +# # Step 4: Evaluate the model on the Test Dataset +# +# ### Generate evaluation metrics and plots such as confusion matrix and ROC curves, F1 score, etc. + +# We are primarily concerned with high false positive rate due to the extreme class imbalance + +# In[21]: + + +def predict(model): + model.eval() + accuracy = 0 + total_samples = 0 + predictions = [] # List to store the predictions + truth_labels = [] # List to store the truth labels + + with torch.no_grad(): + for batch_images, batch_labels in tqdm(test_dataloader, leave=False): + batch_image_pixels, batch_labels = batch_images.pixel_values.squeeze(1).to(device), batch_labels.to(device) + outputs = model(batch_image_pixels).logits + _, predicted = torch.max(outputs, 1) + accuracy += (predicted == batch_labels).sum().item() + total_samples += batch_labels.size(0) + + predictions.extend(predicted.cpu().numpy()) + truth_labels.extend(batch_labels.cpu().numpy()) + + validation_accuracy = accuracy / total_samples + print(f"Test Accuracy: {validation_accuracy:.4f}", flush=True) + return predictions, truth_labels + + +# In[22]: + + +predictions, truth_labels = predict(model) +predictions, truth_labels = np.array(predictions), np.array(truth_labels) +df_predictions = pd.DataFrame(data={"predictions": predictions, "truth_labels": truth_labels}) +df_predictions.to_csv("vit_8_preds.csv", index=False) + + +# In[23]: + + +from sklearn.metrics import confusion_matrix +from sklearn.metrics import ConfusionMatrixDisplay + +#source: https://vitalflux.com/python-draw-confusion-matrix-matplotlib/ + +conf_matrix = confusion_matrix(y_true=truth_labels.astype(int), y_pred=predictions) +fig, ax = plt.subplots(figsize=(7.5, 7.5)) +ax.matshow(conf_matrix, cmap=plt.cm.Blues, alpha=0.3) +for i in range(conf_matrix.shape[0]): + for j in range(conf_matrix.shape[1]): + ax.text(x=j, y=i,s=conf_matrix[i, j], va='center', ha='center', size='xx-large') + +plt.xlabel('Predictions', fontsize=18) +plt.ylabel('Actuals', fontsize=18) 
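+# NOTE: the title below appears to be carried over from the binary (leak/non-leak)
+# version of this pipeline; in this 8-class run the matrix axes index classes C0-C7.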
+plt.title(f'Confusion Matrix (0: Leak, 1:Nonleak)', fontsize=18) +plt.savefig("vit_8_confusion.png") +plt.show() + + +# In[24]: + + +from sklearn.metrics import accuracy_score,f1_score +accuracy = accuracy_score(truth_labels, predictions) + + +# In[25]: + + +print("accuracy: ", accuracy, flush=True) + + +# In[26]: + + +f1 = f1_score(truth_labels, predictions) +print("f1 score: ", f1, flush=True) + + +# In[27]: + + +conf_matrix = confusion_matrix(truth_labels, predictions) +per_class_accuracy = conf_matrix.diagonal() / conf_matrix.sum(axis=1) + + +# In[28]: + + +for label, acc in zip(['C6', 'C7', 'C2', 'C1', 'C5', 'C4', 'C3', 'C0'], per_class_accuracy): + print(f"Class '{label}' Accuracy: {acc:.4f}", flush=True) + + +# + +# diff --git a/MethaneModel/background_subtraction_avg.py b/MethaneModel/background_subtraction_avg.py new file mode 100644 index 0000000..d1270ab --- /dev/null +++ b/MethaneModel/background_subtraction_avg.py @@ -0,0 +1,97 @@ +# Imports +import os +import numpy as np +import cv2 +from sklearn.mixture import GaussianMixture +from tqdm import tqdm +import multiprocessing +import uuid +import matplotlib.pyplot as plt +import pandas as pd +import re + +# Helper Functions + +def calc_median(frames): + median_frame = np.median(frames, axis=0).astype(dtype=np.uint8) + return median_frame + +def doMovingAverageBGS(image, prev_frames): + median_img = calc_median(prev_frames) + image = cv2.absdiff(image, median_img) + return image + +def calc_avg(frames): + average_frame = np.mean(frames).astype(dtype=np.uint8) + return average_frame + +dir_path = os.path.dirname(os.path.realpath("__file__")) +data_dir = os.path.join(dir_path, 'data') +train_data_dir = os.path.join(data_dir, 'train') +test_data_dir = os.path.join(data_dir, 'test') + +# # get all raw video data directories +# vid_path = '/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/data/train/MOV_2559.mp4' +background_path = '/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg' +os.makedirs(background_path, exist_ok=True) + +def get_frames(vid_path, out_path, med_count): + + before_path = os.path.join(out_path, 'before') + after_path = os.path.join(out_path, 'after') + median_path = os.path.join(out_path, 'median') + + print("Before path" + before_path, flush = True) + print("After path" + after_path, flush = True) + print("Median path" + median_path, flush = True) + + os.makedirs(before_path, exist_ok=True) + os.makedirs(after_path, exist_ok=True) + os.makedirs(median_path, exist_ok=True) + + cap = cv2.VideoCapture(vid_path) + + cap.set(cv2.CAP_PROP_POS_MSEC, 0) + + num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + print("Num frames: %d" % num_frames, flush = True) + print("Frames per second: %d" % fps, flush = True) + + background = [] + times = [] + + for i in range(med_count): + success, image = cap.read() + background.append(image) + + cap.set(cv2.CAP_PROP_POS_MSEC, 0) + + for i in range(num_frames): + success, image = cap.read() + time = cap.get(cv2.CAP_PROP_POS_MSEC) + times.append(time) + cv2.imwrite(os.path.join(before_path, 'test%d.jpg' % i), image) + median_background = np.median(background, axis = 0) + cv2.imwrite(os.path.join(median_path, 'test%d.jpg' % i), median_background) + removed_image = image-median_background + cv2.imwrite(os.path.join(after_path, 'test%d.jpg' % i), removed_image) + if (i>=med_count): + background.pop(0) + background.append(image) + + cap.release() + cv2.destroyAllWindows() + + 
df_time = pd.DataFrame(times) + df_time.to_csv(out_path + '/df_time.csv') + +for file in os.listdir(train_data_dir): + vid_path = os.path.join(train_data_dir, file) + vid_id = int(re.findall("_(\d{4}).mp4",os.path.basename(vid_path))[0]) + print("Extracting vid_id: %d" % vid_id, flush = True) + output_path = os.path.join(background_path, str(vid_id)) + os.makedirs(output_path, exist_ok=True) + get_frames(vid_path, output_path, 210) + +print("Completed", flush = True) \ No newline at end of file diff --git a/MethaneModel/background_subtraction_stationary.py b/MethaneModel/background_subtraction_stationary.py new file mode 100644 index 0000000..37654ca --- /dev/null +++ b/MethaneModel/background_subtraction_stationary.py @@ -0,0 +1,92 @@ +# Imports +import os +import numpy as np +import cv2 +from sklearn.mixture import GaussianMixture +from tqdm import tqdm +import multiprocessing +import uuid +import matplotlib.pyplot as plt +import pandas as pd +import re + +# Helper Functions + +def calc_median(frames): + median_frame = np.median(frames, axis=0).astype(dtype=np.uint8) + return median_frame + +def doMovingAverageBGS(image, prev_frames): + median_img = calc_median(prev_frames) + image = cv2.absdiff(image, median_img) + return image + +def calc_avg(frames): + average_frame = np.mean(frames).astype(dtype=np.uint8) + return average_frame + +dir_path = os.path.dirname(os.path.realpath("__file__")) +data_dir = os.path.join(dir_path, 'data') +train_data_dir = os.path.join(data_dir, 'train') +test_data_dir = os.path.join(data_dir, 'test') + +# # get all raw video data directories +# vid_path = '/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/data/train/MOV_2559.mp4' +background_path = '/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg' +os.makedirs(background_path, exist_ok=True) + +def get_frames(vid_path, out_path, med_count): + + before_path = os.path.join(out_path, 'before') + after_path = os.path.join(out_path, 'after') + + print("Before path" + before_path, flush = True) + print("After path" + after_path, flush = True) + + os.makedirs(before_path, exist_ok=True) + os.makedirs(after_path, exist_ok=True) + + cap = cv2.VideoCapture(vid_path) + + cap.set(cv2.CAP_PROP_POS_MSEC, 0) + + num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) +print("Num frames: %d" % num_frames, flush = True) + print("Frames per second: %d" % fps, flush = True) + + background = [] + times = [] + + for i in range(med_count): + success, image = cap.read() + background.append(image) + + median_background = np.median(background, axis = 0) + cv2.imwrite(os.path.join(out_path, 'median.jpg'), median_background) + + cap.set(cv2.CAP_PROP_POS_MSEC, 0) + + for i in range(num_frames): + success, image = cap.read() + time = cap.get(cv2.CAP_PROP_POS_MSEC) + times.append(time) + cv2.imwrite(os.path.join(before_path, 'test%d.jpg' % i), image) + removed_image = image-median_background + cv2.imwrite(os.path.join(after_path, 'test%d.jpg' % i), removed_image) + + cap.release() + cv2.destroyAllWindows() + + df_time = pd.DataFrame(times) + df_time.to_csv(out_path + '/df_time.csv') + +for file in os.listdir(train_data_dir): + vid_path = os.path.join(train_data_dir, file) + vid_id = int(re.findall("_(\d{4}).mp4",os.path.basename(vid_path))[0]) + print("Extracting vid_id: %d" % vid_id, flush = True) + output_path = os.path.join(background_path, str(vid_id)) + os.makedirs(output_path, exist_ok=True) + get_frames(vid_path, 
output_path, 210) + +print("Completed", flush = True) \ No newline at end of file diff --git a/MethaneModel/frame_extraction_complete.ipynb b/MethaneModel/frame_extraction_complete.ipynb new file mode 100644 index 0000000..2252f66 --- /dev/null +++ b/MethaneModel/frame_extraction_complete.ipynb @@ -0,0 +1,624 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Extracting Frames" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "import os\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import cv2\n", + "import torch\n", + "import torch.nn as nn\n", + "import pandas as pd\n", + "import re\n", + "import shutil\n", + "from torchvision import transforms\n", + "from transformers import ViTForImageClassification, ViTImageProcessor, ViTConfig\n", + "from torch.utils.data import DataLoader, Dataset, Subset\n", + "from tqdm.notebook import tqdm\n", + "from sklearn.metrics import RocCurveDisplay, roc_curve, ConfusionMatrixDisplay, confusion_matrix\n", + "from sklearn.mixture import GaussianMixture" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Moving Average Background Subtraction" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "# Helper Functions\n", + "def calc_median(frames):\n", + " median_frame = np.median(frames, axis=0).astype(dtype=np.uint8)\n", + " return median_frame\n", + "\n", + "def doMovingAverageBGS(image, prev_frames):\n", + " median_img = calc_median(prev_frames)\n", + " image = cv2.absdiff(image, median_img)\n", + " return image" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "dir_path = os.path.dirname(os.path.realpath(\"__file__\"))\n", + "\n", + "# get paths to training and testing data\n", + "data_dir = os.path.join(dir_path, 'data')\n", + "train_data_dir = os.path.join(data_dir, 'train')\n", + "test_data_dir = os.path.join(data_dir, 'test')\n", + "\n", + "# get path to folder where extracted frames are stored\n", + "background_path = os.path.join(dir_path, 'background_sub_testing_movingavg2')\n", + "os.makedirs(background_path, exist_ok=True)\n", + "\n", + "#path to where binary frames are stored\n", + "frame_data_dir = os.path.join(dir_path, \"background_sub_movingavg_frames2\")\n", + "frame_train_data_dir = os.path.join(frame_data_dir, 'train')\n", + "frame_test_data_dir = os.path.join(frame_data_dir, 'test')\n", + "\n", + "frame_train_data_dir_nonleak = os.path.join(frame_train_data_dir, 'Nonleaks')\n", + "frame_train_data_dir_leak = os.path.join(frame_train_data_dir, 'Leaks')\n", + "frame_test_data_dir_nonleak = os.path.join(frame_test_data_dir, 'Nonleaks')\n", + "frame_test_data_dir_leak = os.path.join(frame_test_data_dir, 'Leaks')\n", + "\n", + "os.makedirs(frame_data_dir, exist_ok=True)\n", + "os.makedirs(frame_train_data_dir, exist_ok=True)\n", + "os.makedirs(frame_test_data_dir, exist_ok=True)\n", + "\n", + "os.makedirs(frame_train_data_dir_nonleak, exist_ok=True)\n", + "os.makedirs(frame_train_data_dir_leak, exist_ok=True)\n", + "os.makedirs(frame_test_data_dir_nonleak, exist_ok=True)\n", + "os.makedirs(frame_test_data_dir_leak, exist_ok=True)\n", + "\n", + "# get folder to put 8 classes in\n", + "classes_folder = os.path.join(dir_path, 'background_sub_movingavg8_frames2')\n", + "os.makedirs(classes_folder, exist_ok=True)\n", + "\n", + "classes_train_folder = 
os.path.join(classes_folder, 'train')\n", + "os.makedirs(classes_train_folder, exist_ok = True)\n", + "\n", + "classes_test_folder = os.path.join(classes_folder, 'test')\n", + "os.makedirs(classes_test_folder, exist_ok = True)\n", + "\n", + "for i in range(8):\n", + " os.makedirs(os.path.join(classes_train_folder, 'C' + str(i)), exist_ok=True)\n", + " os.makedirs(os.path.join(classes_test_folder, 'C' + str(i)), exist_ok=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "def get_frames(vid_path, out_path, med_count):\n", + "\n", + " before_path = os.path.join(out_path, 'before')\n", + " after_path = os.path.join(out_path, 'after')\n", + " median_path = os.path.join(out_path, 'median')\n", + "\n", + " print(\"Before path\" + before_path, flush = True)\n", + " print(\"After path\" + after_path, flush = True)\n", + " print(\"Median path\" + median_path, flush = True)\n", + "\n", + " os.makedirs(before_path, exist_ok=True)\n", + " os.makedirs(after_path, exist_ok=True)\n", + " os.makedirs(median_path, exist_ok=True)\n", + "\n", + " cap = cv2.VideoCapture(vid_path)\n", + "\n", + " cap.set(cv2.CAP_PROP_POS_MSEC, 0)\n", + " success = True\n", + "\n", + " num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", + " fps = cap.get(cv2.CAP_PROP_FPS) \n", + " print(\"Num frames: %d\" % num_frames, flush = True)\n", + " print(\"Frames per second: %d\" % fps, flush = True)\n", + "\n", + " prev_imgs = []\n", + " times = []\n", + "\n", + " for i in range(med_count):\n", + " success, image = cap.read()\n", + " image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n", + " prev_imgs.append(image) \n", + "\n", + " cap.set(cv2.CAP_PROP_POS_MSEC, 0)\n", + "\n", + " for i in range(num_frames):\n", + " success, image = cap.read()\n", + " image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n", + " time = cap.get(cv2.CAP_PROP_POS_MSEC)\n", + " times.append(time)\n", + " cv2.imwrite(os.path.join(before_path, 'test%d.jpg' % i), image)\n", + " median_background = np.median(prev_imgs, axis = 0)\n", + " cv2.imwrite(os.path.join(median_path, 'test%d.jpg' % i), median_background)\n", + " removed_image = doMovingAverageBGS(image, prev_imgs)\n", + " cv2.imwrite(os.path.join(after_path, 'test%d.jpg' % i), removed_image)\n", + " prev_imgs.pop(0)\n", + " prev_imgs.append(image)\n", + "\n", + " cap.release()\n", + " cv2.destroyAllWindows()\n", + "\n", + " df_time = pd.DataFrame(times)\n", + " df_time.to_csv(out_path + '/df_time.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting vid_id: 2580\n", + "Before path/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg2/2580/before\n", + "After path/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg2/2580/after\n", + "Median path/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg2/2580/median\n", + "Num frames: 22095\n", + "Frames per second: 14\n" + ] + } + ], + "source": [ + "for file in os.listdir(train_data_dir):\n", + " vid_path = os.path.join(train_data_dir, file)\n", + " vid_id = int(re.findall(\"_(\\d{4}).mp4\",os.path.basename(vid_path))[0])\n", + " print(\"Extracting vid_id: %d\" % vid_id, flush = True)\n", + " output_path = os.path.join(background_path, str(vid_id))\n", + " os.makedirs(output_path, exist_ok=True)\n", + " 
get_frames(vid_path, output_path, 210)\n", + "\n", + "for file in os.listdir(test_data_dir):\n", + " vid_path = os.path.join(test_data_dir, file)\n", + " vid_id = int(re.findall(\"_(\\d{4}).mp4\",os.path.basename(vid_path))[0])\n", + " print(\"Extracting vid_id: %d\" % vid_id, flush = True)\n", + " output_path = os.path.join(background_path, str(vid_id))\n", + " os.makedirs(output_path, exist_ok=True)\n", + " get_frames(vid_path, output_path, 210)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Moving frames to binary classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(31, 4)" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ranges = pd.read_csv('/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/GasVid_Ranges_Seconds.csv')\n", + "ranges = ranges.set_index('Video No.')\n", + "ranges.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def copy_frames_nonleak(vid_id, train_test_path):\n", + " col_names = [\"Index\", \"Times\"]\n", + " stat_times = pd.read_csv(os.path.join(background_path, '%d/df_time.csv' % vid_id), names = col_names, header = None)\n", + " stat_times = stat_times.dropna()\n", + "\n", + " nonleak_start = ranges.loc[vid_id,'Nonleak Range Start (s)'] * 1000\n", + " nonleak_end = ranges.loc[vid_id,'Nonleak Range End (s)'] * 1000\n", + "\n", + " start_bool = stat_times['Times'] >= nonleak_start\n", + " end_bool = stat_times['Times'] <= nonleak_end\n", + "\n", + " valid_index = stat_times[start_bool & end_bool]\n", + " valid_index = valid_index.astype({'Index': 'int32'})\n", + "\n", + " folder_before = '/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg/%d/after' %vid_id\n", + " valid_index['Before Filename'] = valid_index[\"Index\"].apply(lambda x: os.path.join(folder_before, 'test%d.jpg'%x))\n", + " valid_index['After Filename'] = valid_index[\"Index\"].apply(lambda x: os.path.join(train_test_path, 'vid%dtest%d.jpg'%(vid_id,x)))\n", + " for i in range(valid_index.shape[0]):\n", + " before = valid_index.iloc[i,]['Before Filename']\n", + " after = valid_index.iloc[i,]['After Filename']\n", + " shutil.copy(before, after)\n", + "\n", + "def copy_frames_leak(vid_id, train_test_path):\n", + " col_names = [\"Index\", \"Times\"]\n", + " stat_times = pd.read_csv('/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg/%d/df_time.csv' % vid_id, names = col_names, header = None)\n", + " stat_times = stat_times.dropna()\n", + "\n", + " leak_start = ranges.loc[vid_id,'Leak Range Start (s)'] * 1000\n", + " leak_end = ranges.loc[vid_id,'Leak Range End (s)'] * 1000\n", + "\n", + " start_bool = stat_times['Times'] >= leak_start\n", + " end_bool = stat_times['Times'] <= leak_end\n", + "\n", + " valid_index = stat_times[start_bool & end_bool]\n", + " valid_index = valid_index.astype({'Index': 'int32'})\n", + "\n", + " folder_before = '/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg/%d/after' %vid_id\n", + " valid_index['Before Filename'] = valid_index[\"Index\"].apply(lambda x: os.path.join(folder_before, 'test%d.jpg'%x))\n", + " valid_index['After Filename'] = valid_index[\"Index\"].apply(lambda x: os.path.join(train_test_path, 'vid%dtest%d.jpg'%(vid_id,x)))\n", + " 
for i in range(valid_index.shape[0]):\n", + " before = valid_index.iloc[i,]['Before Filename']\n", + " after = valid_index.iloc[i,]['After Filename']\n", + " shutil.copy(before, after)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Begin moving training data\n", + "Moving vid_id: 2580\n" + ] + }, + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg2/2580/df_time.csv'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[36], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m vid_id \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(re\u001b[38;5;241m.\u001b[39mfindall(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_(\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124md\u001b[39m\u001b[38;5;132;01m{4}\u001b[39;00m\u001b[38;5;124m).mp4\u001b[39m\u001b[38;5;124m\"\u001b[39m,os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mbasename(vid_path))[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMoving vid_id: \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m vid_id, flush \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 6\u001b[0m \u001b[43mcopy_frames_nonleak\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvid_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mframe_train_data_dir_nonleak\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m copy_frames_leak(vid_id, frame_train_data_dir_leak)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBegin moving testing data\u001b[39m\u001b[38;5;124m\"\u001b[39m, flush \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "Cell \u001b[0;32mIn[35], line 3\u001b[0m, in \u001b[0;36mcopy_frames_nonleak\u001b[0;34m(vid_id, train_test_path)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcopy_frames_nonleak\u001b[39m(vid_id, train_test_path):\n\u001b[1;32m 2\u001b[0m col_names \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIndex\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTimes\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m----> 3\u001b[0m stat_times \u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbackground_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;132;43;01m%d\u001b[39;49;00m\u001b[38;5;124;43m/df_time.csv\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m%\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mvid_id\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnames\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m 
\u001b[49m\u001b[43mcol_names\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheader\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m stat_times \u001b[38;5;241m=\u001b[39m stat_times\u001b[38;5;241m.\u001b[39mdropna()\n\u001b[1;32m 6\u001b[0m nonleak_start \u001b[38;5;241m=\u001b[39m ranges\u001b[38;5;241m.\u001b[39mloc[vid_id,\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNonleak Range Start (s)\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m1000\u001b[39m\n", + "File \u001b[0;32m~/anaconda3/envs/angeline/lib/python3.10/site-packages/pandas/io/parsers/readers.py:1024\u001b[0m, in \u001b[0;36mread_csv\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, date_format, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options, dtype_backend)\u001b[0m\n\u001b[1;32m 1011\u001b[0m kwds_defaults \u001b[38;5;241m=\u001b[39m _refine_defaults_read(\n\u001b[1;32m 1012\u001b[0m dialect,\n\u001b[1;32m 1013\u001b[0m delimiter,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1020\u001b[0m dtype_backend\u001b[38;5;241m=\u001b[39mdtype_backend,\n\u001b[1;32m 1021\u001b[0m )\n\u001b[1;32m 1022\u001b[0m kwds\u001b[38;5;241m.\u001b[39mupdate(kwds_defaults)\n\u001b[0;32m-> 1024\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_read\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/envs/angeline/lib/python3.10/site-packages/pandas/io/parsers/readers.py:618\u001b[0m, in \u001b[0;36m_read\u001b[0;34m(filepath_or_buffer, kwds)\u001b[0m\n\u001b[1;32m 615\u001b[0m _validate_names(kwds\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnames\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[1;32m 617\u001b[0m \u001b[38;5;66;03m# Create the parser.\u001b[39;00m\n\u001b[0;32m--> 618\u001b[0m parser \u001b[38;5;241m=\u001b[39m \u001b[43mTextFileReader\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 620\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m chunksize \u001b[38;5;129;01mor\u001b[39;00m iterator:\n\u001b[1;32m 621\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parser\n", + "File \u001b[0;32m~/anaconda3/envs/angeline/lib/python3.10/site-packages/pandas/io/parsers/readers.py:1618\u001b[0m, in \u001b[0;36mTextFileReader.__init__\u001b[0;34m(self, f, engine, **kwds)\u001b[0m\n\u001b[1;32m 1615\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptions[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhas_index_names\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwds[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhas_index_names\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 
1617\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles: IOHandles \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1618\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_engine \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_engine\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mengine\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/envs/angeline/lib/python3.10/site-packages/pandas/io/parsers/readers.py:1878\u001b[0m, in \u001b[0;36mTextFileReader._make_engine\u001b[0;34m(self, f, engine)\u001b[0m\n\u001b[1;32m 1876\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m mode:\n\u001b[1;32m 1877\u001b[0m mode \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m-> 1878\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles \u001b[38;5;241m=\u001b[39m \u001b[43mget_handle\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1879\u001b[0m \u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1880\u001b[0m \u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mencoding\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[43m \u001b[49m\u001b[43mcompression\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcompression\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1883\u001b[0m \u001b[43m \u001b[49m\u001b[43mmemory_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmemory_map\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1884\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_text\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_text\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1885\u001b[0m \u001b[43m 
\u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mencoding_errors\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstrict\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1886\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstorage_options\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1887\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1888\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1889\u001b[0m f \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles\u001b[38;5;241m.\u001b[39mhandle\n", + "File \u001b[0;32m~/anaconda3/envs/angeline/lib/python3.10/site-packages/pandas/io/common.py:873\u001b[0m, in \u001b[0;36mget_handle\u001b[0;34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[0m\n\u001b[1;32m 868\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(handle, \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 869\u001b[0m \u001b[38;5;66;03m# Check whether the filename is to be opened in binary mode.\u001b[39;00m\n\u001b[1;32m 870\u001b[0m \u001b[38;5;66;03m# Binary mode does not support 'encoding' and 'newline'.\u001b[39;00m\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ioargs\u001b[38;5;241m.\u001b[39mencoding \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m ioargs\u001b[38;5;241m.\u001b[39mmode:\n\u001b[1;32m 872\u001b[0m \u001b[38;5;66;03m# Encoding\u001b[39;00m\n\u001b[0;32m--> 873\u001b[0m handle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 874\u001b[0m \u001b[43m \u001b[49m\u001b[43mhandle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 875\u001b[0m \u001b[43m \u001b[49m\u001b[43mioargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 876\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mioargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 877\u001b[0m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 878\u001b[0m \u001b[43m \u001b[49m\u001b[43mnewline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 879\u001b[0m \u001b[43m 
\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 880\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 881\u001b[0m \u001b[38;5;66;03m# Binary mode\u001b[39;00m\n\u001b[1;32m 882\u001b[0m handle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mopen\u001b[39m(handle, ioargs\u001b[38;5;241m.\u001b[39mmode)\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg2/2580/df_time.csv'" + ] + } + ], + "source": [ + "print(\"Begin moving training data\", flush = True)\n", + "for file in os.listdir(train_data_dir):\n", + " vid_path = os.path.join(train_data_dir, file)\n", + " vid_id = int(re.findall(\"_(\\d{4}).mp4\",os.path.basename(vid_path))[0])\n", + " print(\"Moving vid_id: %d\" % vid_id, flush = True)\n", + " copy_frames_nonleak(vid_id, frame_train_data_dir_nonleak)\n", + " copy_frames_leak(vid_id, frame_train_data_dir_leak)\n", + "\n", + "print(\"Begin moving testing data\", flush = True)\n", + "for file in os.listdir(test_data_dir):\n", + " vid_path = os.path.join(test_data_dir, file)\n", + " vid_id = int(re.findall(\"_(\\d{4}).mp4\",os.path.basename(vid_path))[0])\n", + " print(\"Moving vid_id: %d\" % vid_id, flush = True)\n", + " copy_frames_nonleak(vid_id, frame_test_data_dir_nonleak)\n", + " copy_frames_leak(vid_id, frame_test_data_dir_leak)\n", + "print(\"Completed\", flush = True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Moving frames to 8 classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
C0(S)C0(E)C1(S)C1(E)C2(S)C2(E)C3(S)C3(E)C4(S)C4(E)C5(S)C5(E)C6(S)C6(E)C7(S)C7(E)
Video No.
25643119121137139155157173175191193110911111127112911451
25594420422438440456458474476492494411041124128413041464
25603219221237239255257273275291293210921112127212921452
25613419421437439455457473475491493410941114127412941454
25623919921937939955957973975991993910991119127912991459
\n", + "
" + ], + "text/plain": [ + " C0(S) C0(E) C1(S) C1(E) C2(S) C2(E) C3(S) C3(E) C4(S) \\\n", + "Video No. \n", + "2564 31 191 211 371 391 551 571 731 751 \n", + "2559 44 204 224 384 404 564 584 744 764 \n", + "2560 32 192 212 372 392 552 572 732 752 \n", + "2561 34 194 214 374 394 554 574 734 754 \n", + "2562 39 199 219 379 399 559 579 739 759 \n", + "\n", + " C4(E) C5(S) C5(E) C6(S) C6(E) C7(S) C7(E) \n", + "Video No. \n", + "2564 911 931 1091 1111 1271 1291 1451 \n", + "2559 924 944 1104 1124 1284 1304 1464 \n", + "2560 912 932 1092 1112 1272 1292 1452 \n", + "2561 914 934 1094 1114 1274 1294 1454 \n", + "2562 919 939 1099 1119 1279 1299 1459 " + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ranges = pd.read_csv(\"/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/GasVid_Ranges_all.csv\", index_col = \"Video No.\")\n", + "ranges.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def extract_8_classes(vid_id, train_test_path):\n", + " col_names = [\"Index\", \"Times\"]\n", + " frame_times = pd.read_csv('/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg/%d/df_time.csv' % vid_id, names = col_names, header = None)\n", + " frame_times = frame_times.dropna()\n", + " folder_vid = '/home/bestlab/Desktop/Squishy-Methane-URAP-New/AngelineLee/MethaneModel/background_sub_testing_movingavg/%d/after' %vid_id\n", + "\n", + " for i in range(8):\n", + " start_time = ranges.loc[vid_id, 'C%d(S)'%i] * 1000\n", + " end_time = ranges.loc[vid_id, 'C%d(E)'%i] * 1000\n", + "\n", + " start_bool = frame_times['Times'] >= start_time\n", + " end_bool = frame_times['Times'] <= end_time\n", + " valid_index = frame_times[start_bool & end_bool]\n", + " valid_index = valid_index.astype({'Index': 'int32'})\n", + " \n", + " valid_index['Before Filename'] = valid_index[\"Index\"].apply(lambda x: os.path.join(folder_vid, 'test%d.jpg'%x))\n", + " class_folder = os.path.join(train_test_path, 'C' + str(i))\n", + " valid_index['After Filename'] = valid_index[\"Index\"].apply(lambda x: os.path.join(class_folder, 'vid%dtest%d.jpg'%(vid_id,x)))\n", + " for i in range(valid_index.shape[0]):\n", + " before = valid_index.iloc[i,]['Before Filename']\n", + " after = valid_index.iloc[i,]['After Filename']\n", + " shutil.copy(before, after)\n", + " print('Finished moving video %d'%vid_id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Moving vid_id: 2580\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[40], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m extract_8_classes(vid_id, classes_test_folder)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m2\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m----> 8\u001b[0m \u001b[43mextract_8_classes\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvid_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclasses_train_folder\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[38], line 22\u001b[0m, in 
\u001b[0;36mextract_8_classes\u001b[0;34m(vid_id, train_test_path)\u001b[0m\n\u001b[1;32m 20\u001b[0m before \u001b[38;5;241m=\u001b[39m valid_index\u001b[38;5;241m.\u001b[39miloc[i,][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mBefore Filename\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 21\u001b[0m after \u001b[38;5;241m=\u001b[39m valid_index\u001b[38;5;241m.\u001b[39miloc[i,][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAfter Filename\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m---> 22\u001b[0m \u001b[43mshutil\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbefore\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mafter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFinished moving video \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m%\u001b[39mvid_id)\n", + "File \u001b[0;32m~/anaconda3/envs/angeline/lib/python3.10/shutil.py:417\u001b[0m, in \u001b[0;36mcopy\u001b[0;34m(src, dst, follow_symlinks)\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(dst):\n\u001b[1;32m 416\u001b[0m dst \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(dst, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mbasename(src))\n\u001b[0;32m--> 417\u001b[0m \u001b[43mcopyfile\u001b[49m\u001b[43m(\u001b[49m\u001b[43msrc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdst\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfollow_symlinks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfollow_symlinks\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 418\u001b[0m copymode(src, dst, follow_symlinks\u001b[38;5;241m=\u001b[39mfollow_symlinks)\n\u001b[1;32m 419\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m dst\n", + "File \u001b[0;32m~/anaconda3/envs/angeline/lib/python3.10/shutil.py:233\u001b[0m, in \u001b[0;36mcopyfile\u001b[0;34m(src, dst, follow_symlinks)\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Copy data from src to dst in the most efficient way possible.\u001b[39;00m\n\u001b[1;32m 226\u001b[0m \n\u001b[1;32m 227\u001b[0m \u001b[38;5;124;03mIf follow_symlinks is not set and src is a symbolic link, a new\u001b[39;00m\n\u001b[1;32m 228\u001b[0m \u001b[38;5;124;03msymlink will be created instead of copying the file it points to.\u001b[39;00m\n\u001b[1;32m 229\u001b[0m \n\u001b[1;32m 230\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 231\u001b[0m sys\u001b[38;5;241m.\u001b[39maudit(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshutil.copyfile\u001b[39m\u001b[38;5;124m\"\u001b[39m, src, dst)\n\u001b[0;32m--> 233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43m_samefile\u001b[49m\u001b[43m(\u001b[49m\u001b[43msrc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdst\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 234\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m SameFileError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{!r}\u001b[39;00m\u001b[38;5;124m and \u001b[39m\u001b[38;5;132;01m{!r}\u001b[39;00m\u001b[38;5;124m are the same file\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(src, dst))\n\u001b[1;32m 236\u001b[0m file_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "for file in 
os.listdir(background_path):\n", + " vid_path = os.path.join(background_path, file)\n", + " vid_id = int(file)\n", + " print(\"Moving vid_id: %d\" % vid_id, flush = True)\n", + " if file[0] == '1':\n", + " extract_8_classes(vid_id, classes_test_folder)\n", + " if file[0] == '2':\n", + " extract_8_classes(vid_id, classes_train_folder)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "angeline", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}