diff --git a/.gitignore b/.gitignore
index 30db9cd..e99604f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -229,12 +229,13 @@ settings.json
 # Block all configs besides the example config
 whoot_model_training/configs
 !whoot_model_training/configs/config.yml
-
-# Block demos
+*.csv
+*.ipynb
+*.json
 demos/
 *.ipynb
 
 # Block predictions
 predictions/*
 *.pkl
-*.arrow
\ No newline at end of file
+*.arrow
diff --git a/data_downloader/downloader_demo.ipynb b/data_downloader/downloader_demo.ipynb
new file mode 100644
index 0000000..487dd4e
--- /dev/null
+++ b/data_downloader/downloader_demo.ipynb
@@ -0,0 +1,224 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3842a3a9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import json\n",
+    "import requests\n",
+    "from xc import XenoCantoDownloader\n",
+    "from dotenv import load_dotenv\n",
+    "import pandas as pd\n",
+    "import seaborn as sns\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "\n",
+    "# Load environment variables from the .env file\n",
+    "load_dotenv()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fa77bd28",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "xcd = XenoCantoDownloader(api_key=os.environ[\"XC_API_KEY\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3ee1304b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query = xcd.build_query()\n",
+    "res = xcd.get_page(query)\n",
+    "res"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9fed0106",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = xcd(query=\"box:32.485,-117.582,33.482,-115.228\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a43fcef6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "with open(\"xc_meta.json\", mode=\"w\") as f:\n",
+    "    json.dump(data, f, indent=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0942a027",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "req = requests.get(data[0][\"recordings\"][0][\"file\"])\n",
+    "req"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9d62a7a6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data[0][\"recordings\"][0][\"file\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dcc63419",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "req.content"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "30af4509",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import shutil\n",
+    "import os\n",
+    "from pathlib import Path\n",
+    "from multiprocessing.pool import ThreadPool\n",
+    "\n",
+    "# https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests\n",
+    "def download_file(url, local_filename, dry_run=False):\n",
+    "    if os.path.exists(local_filename):\n",
+    "        return local_filename\n",
+    "\n",
+    "    # Check dry_run before opening the file so dry runs never leave\n",
+    "    # empty files behind (which would be skipped on the next real run)\n",
+    "    if dry_run:\n",
+    "        print(local_filename)\n",
+    "        return local_filename\n",
+    "\n",
+    "    try:\n",
+    "        with requests.get(url, stream=True, timeout=1000) as r:\n",
+    "            with open(local_filename, 'wb') as f:\n",
+    "                shutil.copyfileobj(r.raw, f)\n",
+    "\n",
+    "        return local_filename\n",
+    "    except IOError as e:\n",
+    "        print(e, flush=True)\n",
+    "        return None\n",
+    "\n",
+    "def download_files(xcd, data, parent_folder=\"data/xeno-canto\", workers=4):\n",
+    "    def prep_download(args):\n",
+    "        url = args[0]\n",
+    "        file_path = args[1]\n",
+    "        return download_file(url, file_path)\n",
+    "\n",
+    "    os.makedirs(parent_folder, exist_ok=True)\n",
+    "\n",
+    "    if 
\"recordings\" in data[0]:\n", + " data = xcd.concat_recording_data(data) \n", + " download_data = [\n", + " (recording[\"file\"], Path(parent_folder) / Path(recording[\"file-name\"]))\n", + " for recording in data\n", + " ]\n", + " pool = ThreadPool(workers)\n", + " results = pool.imap_unordered(prep_download, download_data) \n", + " pool.close()\n", + " return results\n", + "\n", + "download_files(xcd, data)" + ] + }, + { + "cell_type": "markdown", + "id": "ea02004c", + "metadata": {}, + "source": [ + "# Study" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5bf99a36", + "metadata": {}, + "outputs": [], + "source": [ + "recordings = xcd.concat_recording_data(data)\n", + "df = pd.DataFrame(recordings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4e26ec8", + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(df[\"en\"].value_counts())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9d7dc37", + "metadata": {}, + "outputs": [], + "source": [ + "plt.ylabel(\"Number of Species\")\n", + "plt.xlabel(\"Number of Indivuals Per Species\")\n", + "plt.title(\"Do We Have a Few-shot Learning Problem for XC in Southern California?\")\n", + "df[\"en\"].value_counts().hist()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c04ef6f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "whoot", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/data_downloader/xc.py b/data_downloader/xc.py new file mode 100644 index 0000000..971f454 --- /dev/null +++ b/data_downloader/xc.py @@ -0,0 +1,123 @@ +"""Xeno-Canto Data Metadata Downloader and Search Module.""" +import os +import urllib.parse +import json +import requests + + +class XenoCantoDownloader(): + """Handler for Xeno-Canto API. + + Note: Requires an API key from env var "XC_API_KEY". + Third version of the Xeno-Canto API is used here. + """ + def __init__(self, api_key=None): + """Creates the Xeno-Canto Downloader. + + Args: + api_key (str): API key for Xeno-Canto API. + If None, looks for env var "XC_API_KEY" + """ + self.endpoint_url = "https://xeno-canto.org/api/3/recordings" + self.api_key = os.environ["XC_API_KEY"] if api_key is None else api_key + assert self.api_key is not None, \ + "API KEY MISSING: Put API key in Environment Var!" + + def __call__(self, + query=None, + loc=None, + ): + r"""Download XC data. + + Initally, this was intended to be used to build queries + So more args were planned (hence loc). In practice, it was easier + to build queries by hand ¯\_(ツ)_/¯ + + You can pull the query you want from the url on the website if you + are manually searching for thigns there. Its the same syntax. + + Also is useful for debugging issues there + + Args: + query (str/None): Search query string see XC Search Tags + loc (str/None): Location string for search query + """ + if query is None: + query = self.build_query( + loc=loc, + ) + + page_datas = [] + page_data = self.get_page(query, page=1) + page_datas.append(page_data) + + # Get rest of data! 
+        for i in range(2, page_data["numPages"] + 1):
+            page_data = self.get_page(query, page=i)
+            page_datas.append(page_data)
+
+        return page_datas
+
+    def concat_recording_data(self, page_datas):
+        """Concatenate recording data from multiple pages.
+
+        Args:
+            page_datas (list): list of page data dicts
+        """
+        new_page_data = []
+        for page_data in page_datas:
+            new_page_data = new_page_data + page_data["recordings"]
+        return new_page_data
+
+    def build_query(
+        self,
+        loc="San Diego, California, United States of America",
+        # box=None,
+    ):
+        """Builds a query string for the Xeno-Canto API.
+
+        See https://xeno-canto.org/help/search
+        Args:
+            loc (str): Location string for the search query
+        """
+        search_tags = ""
+        if loc is not None:
+            search_tags += f"loc:\"{loc}\"+"
+        # Remove trailing +
+        return search_tags[:-1]
+
+    def get_page(self, query, page=1):
+        """Get a page of results from the Xeno-Canto API.
+
+        Args:
+            query (str): Search query string, see XC Search Tags
+            page (int): Page number to retrieve
+        """
+        res = requests.get(
+            self.endpoint_url + "?" + urllib.parse.urlencode({
+                "query": query,
+                "key": self.api_key,
+                "page": page
+            }),
+            timeout=100
+        )
+        if res.status_code == 200:
+            return json.loads(res.text)
+
+        return {}
+    # def download_files(self, data):
+    #     if type(data) == dict:
+    #         data = self.concat_recording_data(self, data)
+    #     for recording in data:
+    #         requests
+
+
+if __name__ == "__main__":
+    # parser = argparse.ArgumentParser(
+    #     description='Input Directory Path'
+    # )
+    # parser.add_argument('meta', type=str,
+    #                     help='Path to metadata csv')
+    # args = parser.parse_args()
+    xcd = XenoCantoDownloader()
+    print(xcd())
diff --git a/data_downloader/xc_aux_downloader.py b/data_downloader/xc_aux_downloader.py
new file mode 100644
index 0000000..cba00cc
--- /dev/null
+++ b/data_downloader/xc_aux_downloader.py
@@ -0,0 +1,119 @@
+"""Downloads auxiliary Xeno-Canto data and audio files.
+
+Relies on output from data_downloader/xc.py
+Create a .env file with the XC API key:
+`XC_API_KEY=your_api_key_here`
+Then call directly with `python xc_aux_downloader.py`
+"""
+
+import shutil
+import os
+import json
+import itertools
+from pathlib import Path
+from multiprocessing.pool import ThreadPool
+from dotenv import load_dotenv
+import pandas as pd
+import tqdm
+import requests
+from xc import XenoCantoDownloader
+
+
+# https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests
+def download_file(url, local_filename, dry_run=False):
+    """Download a file from a url to a local file.
+
+    Args:
+        url (str): url to download the file from
+        local_filename (str): path of the local file to save to
+        dry_run (bool): if True, do not actually download the file
+    Returns:
+        local_filename (str): path to the local file, or None if failed
+    """
+    if os.path.exists(local_filename):
+        return local_filename
+
+    # Check dry_run before opening the file so dry runs never leave empty
+    # files behind (which the exists() check above would then skip forever)
+    if dry_run:
+        print("Pretend download of", local_filename)
+        return local_filename
+
+    try:
+        with requests.get(url, stream=True, timeout=1000) as r:
+            with open(local_filename, 'wb') as f:
+                shutil.copyfileobj(r.raw, f)
+
+        return local_filename
+    except IOError as e:
+        print(e, flush=True)
+        return None
+
+
+def download_files(
+    xcd: XenoCantoDownloader,
+    data: list,
+    parent_folder: str = "data/xeno-canto_aux",
+    workers: int = 4
+):
+    """Download all the files collected by the Xeno-Canto downloader.
+
+    Args:
+        xcd (XenoCantoDownloader): the Xeno-Canto downloader object
+            Allows for preprocessing of recording metadata
+        data (list): list of recording data dicts
+        parent_folder (str): path to folder to store audio files
+        workers (int): number of parallel download workers
+            Tune down if hitting rate limits
+    Returns:
+        results (list): list of downloaded file paths
+    """
+    def prep_download(args):
+        url = args[0]
+        file_path = args[1]
+        return download_file(url, file_path)
+
+    os.makedirs(parent_folder, exist_ok=True)
+
+    if "recordings" in data[0]:
+        data = xcd.concat_recording_data(data)
+    download_data = [
+        (recording["file"], Path(parent_folder) / Path(recording["file-name"]))
+        for recording in data
+    ]
+    pool = ThreadPool(workers)
+    results = pool.imap_unordered(prep_download, download_data)
+    pool.close()
+    # Materialize so len() works downstream and all downloads finish first
+    return list(results)
+
+
+def main():
+    """Script to download auxiliary Xeno-Canto data and audio files."""
+    # Load environment variables from the .env file
+    load_dotenv()
+
+    xcd = XenoCantoDownloader(api_key=os.environ["XC_API_KEY"])
+
+    with open("data/xc_meta.json", mode="r", encoding="utf-8") as f:
+        data = json.load(f)
+
+    species = {
+        recording["en"] for page in data for recording in page["recordings"]
+    }
+
+    data = []
+    for specie in tqdm.tqdm(list(species)):
+        data.append(xcd(query=f'en:"{specie}"'))
+
+    data = list(itertools.chain.from_iterable(data))
+
+    with open("xc_meta_aux.json", mode="w", encoding="utf-8") as f:
+        json.dump(data, f, indent=4)
+    results = download_files(xcd, data)
+    print("Done downloading files, num downloaded:", len(results))
+
+    recordings = xcd.concat_recording_data(data)
+    df = pd.DataFrame(recordings)
+
+    print("Metadata has shape:", df.shape)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index fa1780d..2f57094 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,10 +7,14 @@ requires-python = ">= 3.10.0, < 3.13.0"
 dependencies = [
     "librosa>=0.10.2.post1",
     "numba==0.61.0",
+    "nvitop>=1.5.3",
     "pandas>=2.3.0",
     "pydub>=0.25.1",
+    "python-dotenv>=1.1.1",
     "pyyaml>=6.0.2",
     "scikit-learn>=1.7.0",
+    "soundfile>=0.13.1",
+    "torchaudio>=2.8.0",
     "tqdm>=4.67.1",
 ]
@@ -27,10 +31,12 @@ version = {attr = "whoot.__version__"}
 [project.optional-dependencies]
 cpu = [
     "torch>=2.7.0",
+    "torchaudio>=2.8.0",
     "torchvision>=0.22.0",
 ]
 cu128 = [
     "torch>=2.7.0",
+    "torchaudio>=2.8.0",
     "torchvision>=0.22.0",
 ]
 model-training = [
@@ -40,22 +46,25 @@ model-training = [
     "comet-ml>=3.43.2",
 ]
 
+perch = [
+    "perch-hoplite>=0.1.0",
+    "tensorflow-hub>=0.16.1",
+    "tensorflow[and-cuda]>=2.20.0",
+]
+
 notebooks = [
     "ipykernel>=6.29.5",
     "ipywidgets>=8.1.6",
-    "matplotlib>=3.10.5",
+    "matplotlib>=3.10.6",
     "seaborn>=0.13.2",
 ]
 
-birdnet = [
-    "birdnet>=0.1.7",
-]
 
 [packages.index]
 cu128 = "https://download.pytorch.org/whl/cu128"
 
 [tool.setuptools]
-packages = ["make_model", "assess_birdnet", "whoot_model_training"]
+packages = ["make_model", "assess_birdnet", "whoot_model_training", "data_downloader"]
 
 [tool.uv.sources]
 pyha-analyzer = { git = "https://github.com/UCSD-E4E/pyha-analyzer-2.0.git", branch = "support_whoot" }
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..21c7a5f
--- /dev/null
+++ b/test.py
@@ -0,0 +1,79 @@
+# # %%
+# %load_ext autoreload
+# %autoreload 1
+
+# %%
+
+from import WaveformInputPreprocessor
+from whoot_model_training.whoot_model_training.models import HFInput, HFModel, HFModelConfig
+from whoot_model_training.whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments
+from 
whoot_model_training.whoot_model_training.data_extractor import xc_extractor +from whoot_model_training.whoot_model_training import CometMLLoggerSupplement + + +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# %% +ds = xc_extractor( + XC_dataset_json_path="xc_meta_aux.json", + parent_path="/mnt/acoustics/san_diego_xc_aux/xeno-canto" +) + + + +model = HFModel(HFModelConfig(num_classes=ds.get_number_species())) + +# %% +# %% + +input_wrapper = HFInput() + +train_preprocessor = WaveformInputPreprocessor( + input_wrapper, duration=3 +) + +preprocessor = WaveformInputPreprocessor( + input_wrapper, duration=3 +) + +ds["train"].set_transform(train_preprocessor) +ds["valid"].set_transform(preprocessor) +ds["test"].set_transform(preprocessor) + +print(ds.get_class_labels()) + +# run_name = "fewshot_test_birdmae" +# subproject_name = "fewshot_test" +# dataset_name = "san_diego_xc_aux_09_2025" + +# training_args = WhootTrainingArguments( +# run_name=run_name, +# subproject_name=subproject_name, +# dataset_name=dataset_name, +# ) + +# # COMMON OPTIONAL ARGS +# training_args.num_train_epochs = 100 +# training_args.eval_steps = 2000 +# training_args.per_device_train_batch_size = 16 +# training_args.per_device_eval_batch_size = 16 +# training_args.dataloader_num_workers = 16 +# training_args.run_name = run_name +# training_args.learning_rate = 0.01 +# training_args.save_strategy="steps", # Save at the end of each epoch +# training_args.save_total_limit=2 # Keep only the last 2 checkpoints + +# trainer = WhootTrainer( +# model=model, +# dataset=ds, +# training_args=training_args, +# logger=CometMLLoggerSupplement( +# augmentations=None, +# name=training_args.run_name +# ), +# ) + +# trainer.train() +# model.save_pretrained("model_checkpoints/fewshot_test_birdmae") + diff --git a/timm_check.ipynb b/timm_check.ipynb new file mode 100644 index 0000000..d4b2c02 --- /dev/null +++ b/timm_check.ipynb @@ -0,0 +1,3682 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 63, + "id": "4877adf6", + "metadata": {}, + "outputs": [], + "source": [ + "import timm" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "cd16456e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['aimv2_1b_patch14_224',\n", + " 'aimv2_1b_patch14_336',\n", + " 'aimv2_1b_patch14_448',\n", + " 'aimv2_3b_patch14_224',\n", + " 'aimv2_3b_patch14_336',\n", + " 'aimv2_3b_patch14_448',\n", + " 'aimv2_huge_patch14_224',\n", + " 'aimv2_huge_patch14_336',\n", + " 'aimv2_huge_patch14_448',\n", + " 'aimv2_large_patch14_224',\n", + " 'aimv2_large_patch14_336',\n", + " 'aimv2_large_patch14_448',\n", + " 'bat_resnext26ts',\n", + " 'beit3_base_patch16_224',\n", + " 'beit3_giant_patch14_224',\n", + " 'beit3_giant_patch14_336',\n", + " 'beit3_large_patch16_224',\n", + " 'beit_base_patch16_224',\n", + " 'beit_base_patch16_384',\n", + " 'beit_large_patch16_224',\n", + " 'beit_large_patch16_384',\n", + " 'beit_large_patch16_512',\n", + " 'beitv2_base_patch16_224',\n", + " 'beitv2_large_patch16_224',\n", + " 'botnet26t_256',\n", + " 'botnet50ts_256',\n", + " 'caformer_b36',\n", + " 'caformer_m36',\n", + " 'caformer_s18',\n", + " 'caformer_s36',\n", + " 'cait_m36_384',\n", + " 'cait_m48_448',\n", + " 'cait_s24_224',\n", + " 'cait_s24_384',\n", + " 'cait_s36_384',\n", + " 'cait_xs24_384',\n", + " 'cait_xxs24_224',\n", + " 'cait_xxs24_384',\n", + " 'cait_xxs36_224',\n", + " 'cait_xxs36_384',\n", + " 'coat_lite_medium',\n", + " 'coat_lite_medium_384',\n", + " 'coat_lite_mini',\n", + " 'coat_lite_small',\n", + 
" 'coat_lite_tiny',\n", + " 'coat_mini',\n", + " 'coat_small',\n", + " 'coat_tiny',\n", + " 'coatnet_0_224',\n", + " 'coatnet_0_rw_224',\n", + " 'coatnet_1_224',\n", + " 'coatnet_1_rw_224',\n", + " 'coatnet_2_224',\n", + " 'coatnet_2_rw_224',\n", + " 'coatnet_3_224',\n", + " 'coatnet_3_rw_224',\n", + " 'coatnet_4_224',\n", + " 'coatnet_5_224',\n", + " 'coatnet_bn_0_rw_224',\n", + " 'coatnet_nano_cc_224',\n", + " 'coatnet_nano_rw_224',\n", + " 'coatnet_pico_rw_224',\n", + " 'coatnet_rmlp_0_rw_224',\n", + " 'coatnet_rmlp_1_rw2_224',\n", + " 'coatnet_rmlp_1_rw_224',\n", + " 'coatnet_rmlp_2_rw_224',\n", + " 'coatnet_rmlp_2_rw_384',\n", + " 'coatnet_rmlp_3_rw_224',\n", + " 'coatnet_rmlp_nano_rw_224',\n", + " 'coatnext_nano_rw_224',\n", + " 'convformer_b36',\n", + " 'convformer_m36',\n", + " 'convformer_s18',\n", + " 'convformer_s36',\n", + " 'convit_base',\n", + " 'convit_small',\n", + " 'convit_tiny',\n", + " 'convmixer_768_32',\n", + " 'convmixer_1024_20_ks9_p14',\n", + " 'convmixer_1536_20',\n", + " 'convnext_atto',\n", + " 'convnext_atto_ols',\n", + " 'convnext_atto_rms',\n", + " 'convnext_base',\n", + " 'convnext_femto',\n", + " 'convnext_femto_ols',\n", + " 'convnext_large',\n", + " 'convnext_large_mlp',\n", + " 'convnext_nano',\n", + " 'convnext_nano_ols',\n", + " 'convnext_pico',\n", + " 'convnext_pico_ols',\n", + " 'convnext_small',\n", + " 'convnext_tiny',\n", + " 'convnext_tiny_hnf',\n", + " 'convnext_xlarge',\n", + " 'convnext_xxlarge',\n", + " 'convnext_zepto_rms',\n", + " 'convnext_zepto_rms_ols',\n", + " 'convnextv2_atto',\n", + " 'convnextv2_base',\n", + " 'convnextv2_femto',\n", + " 'convnextv2_huge',\n", + " 'convnextv2_large',\n", + " 'convnextv2_nano',\n", + " 'convnextv2_pico',\n", + " 'convnextv2_small',\n", + " 'convnextv2_tiny',\n", + " 'crossvit_9_240',\n", + " 'crossvit_9_dagger_240',\n", + " 'crossvit_15_240',\n", + " 'crossvit_15_dagger_240',\n", + " 'crossvit_15_dagger_408',\n", + " 'crossvit_18_240',\n", + " 'crossvit_18_dagger_240',\n", + " 'crossvit_18_dagger_408',\n", + " 'crossvit_base_240',\n", + " 'crossvit_small_240',\n", + " 'crossvit_tiny_240',\n", + " 'cs3darknet_focus_l',\n", + " 'cs3darknet_focus_m',\n", + " 'cs3darknet_focus_s',\n", + " 'cs3darknet_focus_x',\n", + " 'cs3darknet_l',\n", + " 'cs3darknet_m',\n", + " 'cs3darknet_s',\n", + " 'cs3darknet_x',\n", + " 'cs3edgenet_x',\n", + " 'cs3se_edgenet_x',\n", + " 'cs3sedarknet_l',\n", + " 'cs3sedarknet_x',\n", + " 'cs3sedarknet_xdw',\n", + " 'cspdarknet53',\n", + " 'cspresnet50',\n", + " 'cspresnet50d',\n", + " 'cspresnet50w',\n", + " 'cspresnext50',\n", + " 'darknet17',\n", + " 'darknet21',\n", + " 'darknet53',\n", + " 'darknetaa53',\n", + " 'davit_base',\n", + " 'davit_base_fl',\n", + " 'davit_giant',\n", + " 'davit_huge',\n", + " 'davit_huge_fl',\n", + " 'davit_large',\n", + " 'davit_small',\n", + " 'davit_tiny',\n", + " 'deit3_base_patch16_224',\n", + " 'deit3_base_patch16_384',\n", + " 'deit3_huge_patch14_224',\n", + " 'deit3_large_patch16_224',\n", + " 'deit3_large_patch16_384',\n", + " 'deit3_medium_patch16_224',\n", + " 'deit3_small_patch16_224',\n", + " 'deit3_small_patch16_384',\n", + " 'deit_base_distilled_patch16_224',\n", + " 'deit_base_distilled_patch16_384',\n", + " 'deit_base_patch16_224',\n", + " 'deit_base_patch16_384',\n", + " 'deit_small_distilled_patch16_224',\n", + " 'deit_small_patch16_224',\n", + " 'deit_tiny_distilled_patch16_224',\n", + " 'deit_tiny_patch16_224',\n", + " 'densenet121',\n", + " 'densenet161',\n", + " 'densenet169',\n", + " 'densenet201',\n", + " 
'densenet264d',\n", + " 'densenetblur121d',\n", + " 'dla34',\n", + " 'dla46_c',\n", + " 'dla46x_c',\n", + " 'dla60',\n", + " 'dla60_res2net',\n", + " 'dla60_res2next',\n", + " 'dla60x',\n", + " 'dla60x_c',\n", + " 'dla102',\n", + " 'dla102x',\n", + " 'dla102x2',\n", + " 'dla169',\n", + " 'dm_nfnet_f0',\n", + " 'dm_nfnet_f1',\n", + " 'dm_nfnet_f2',\n", + " 'dm_nfnet_f3',\n", + " 'dm_nfnet_f4',\n", + " 'dm_nfnet_f5',\n", + " 'dm_nfnet_f6',\n", + " 'dpn48b',\n", + " 'dpn68',\n", + " 'dpn68b',\n", + " 'dpn92',\n", + " 'dpn98',\n", + " 'dpn107',\n", + " 'dpn131',\n", + " 'eca_botnext26ts_256',\n", + " 'eca_halonext26ts',\n", + " 'eca_nfnet_l0',\n", + " 'eca_nfnet_l1',\n", + " 'eca_nfnet_l2',\n", + " 'eca_nfnet_l3',\n", + " 'eca_resnet33ts',\n", + " 'eca_resnext26ts',\n", + " 'eca_vovnet39b',\n", + " 'ecaresnet26t',\n", + " 'ecaresnet50d',\n", + " 'ecaresnet50d_pruned',\n", + " 'ecaresnet50t',\n", + " 'ecaresnet101d',\n", + " 'ecaresnet101d_pruned',\n", + " 'ecaresnet200d',\n", + " 'ecaresnet269d',\n", + " 'ecaresnetlight',\n", + " 'ecaresnext26t_32x4d',\n", + " 'ecaresnext50t_32x4d',\n", + " 'edgenext_base',\n", + " 'edgenext_small',\n", + " 'edgenext_small_rw',\n", + " 'edgenext_x_small',\n", + " 'edgenext_xx_small',\n", + " 'efficientformer_l1',\n", + " 'efficientformer_l3',\n", + " 'efficientformer_l7',\n", + " 'efficientformerv2_l',\n", + " 'efficientformerv2_s0',\n", + " 'efficientformerv2_s1',\n", + " 'efficientformerv2_s2',\n", + " 'efficientnet_b0',\n", + " 'efficientnet_b0_g8_gn',\n", + " 'efficientnet_b0_g16_evos',\n", + " 'efficientnet_b0_gn',\n", + " 'efficientnet_b1',\n", + " 'efficientnet_b1_pruned',\n", + " 'efficientnet_b2',\n", + " 'efficientnet_b2_pruned',\n", + " 'efficientnet_b3',\n", + " 'efficientnet_b3_g8_gn',\n", + " 'efficientnet_b3_gn',\n", + " 'efficientnet_b3_pruned',\n", + " 'efficientnet_b4',\n", + " 'efficientnet_b5',\n", + " 'efficientnet_b6',\n", + " 'efficientnet_b7',\n", + " 'efficientnet_b8',\n", + " 'efficientnet_blur_b0',\n", + " 'efficientnet_cc_b0_4e',\n", + " 'efficientnet_cc_b0_8e',\n", + " 'efficientnet_cc_b1_8e',\n", + " 'efficientnet_el',\n", + " 'efficientnet_el_pruned',\n", + " 'efficientnet_em',\n", + " 'efficientnet_es',\n", + " 'efficientnet_es_pruned',\n", + " 'efficientnet_h_b5',\n", + " 'efficientnet_l2',\n", + " 'efficientnet_lite0',\n", + " 'efficientnet_lite1',\n", + " 'efficientnet_lite2',\n", + " 'efficientnet_lite3',\n", + " 'efficientnet_lite4',\n", + " 'efficientnet_x_b3',\n", + " 'efficientnet_x_b5',\n", + " 'efficientnetv2_l',\n", + " 'efficientnetv2_m',\n", + " 'efficientnetv2_rw_m',\n", + " 'efficientnetv2_rw_s',\n", + " 'efficientnetv2_rw_t',\n", + " 'efficientnetv2_s',\n", + " 'efficientnetv2_xl',\n", + " 'efficientvit_b0',\n", + " 'efficientvit_b1',\n", + " 'efficientvit_b2',\n", + " 'efficientvit_b3',\n", + " 'efficientvit_l1',\n", + " 'efficientvit_l2',\n", + " 'efficientvit_l3',\n", + " 'efficientvit_m0',\n", + " 'efficientvit_m1',\n", + " 'efficientvit_m2',\n", + " 'efficientvit_m3',\n", + " 'efficientvit_m4',\n", + " 'efficientvit_m5',\n", + " 'ese_vovnet19b_dw',\n", + " 'ese_vovnet19b_slim',\n", + " 'ese_vovnet19b_slim_dw',\n", + " 'ese_vovnet39b',\n", + " 'ese_vovnet39b_evos',\n", + " 'ese_vovnet57b',\n", + " 'ese_vovnet99b',\n", + " 'eva02_base_patch14_224',\n", + " 'eva02_base_patch14_448',\n", + " 'eva02_base_patch16_clip_224',\n", + " 'eva02_enormous_patch14_clip_224',\n", + " 'eva02_large_patch14_224',\n", + " 'eva02_large_patch14_448',\n", + " 'eva02_large_patch14_clip_224',\n", + " 
'eva02_large_patch14_clip_336',\n", + " 'eva02_small_patch14_224',\n", + " 'eva02_small_patch14_336',\n", + " 'eva02_tiny_patch14_224',\n", + " 'eva02_tiny_patch14_336',\n", + " 'eva_giant_patch14_224',\n", + " 'eva_giant_patch14_336',\n", + " 'eva_giant_patch14_560',\n", + " 'eva_giant_patch14_clip_224',\n", + " 'eva_large_patch14_196',\n", + " 'eva_large_patch14_336',\n", + " 'fasternet_l',\n", + " 'fasternet_m',\n", + " 'fasternet_s',\n", + " 'fasternet_t0',\n", + " 'fasternet_t1',\n", + " 'fasternet_t2',\n", + " 'fastvit_ma36',\n", + " 'fastvit_mci0',\n", + " 'fastvit_mci1',\n", + " 'fastvit_mci2',\n", + " 'fastvit_s12',\n", + " 'fastvit_sa12',\n", + " 'fastvit_sa24',\n", + " 'fastvit_sa36',\n", + " 'fastvit_t8',\n", + " 'fastvit_t12',\n", + " 'fbnetc_100',\n", + " 'fbnetv3_b',\n", + " 'fbnetv3_d',\n", + " 'fbnetv3_g',\n", + " 'flexivit_base',\n", + " 'flexivit_large',\n", + " 'flexivit_small',\n", + " 'focalnet_base_lrf',\n", + " 'focalnet_base_srf',\n", + " 'focalnet_huge_fl3',\n", + " 'focalnet_huge_fl4',\n", + " 'focalnet_large_fl3',\n", + " 'focalnet_large_fl4',\n", + " 'focalnet_small_lrf',\n", + " 'focalnet_small_srf',\n", + " 'focalnet_tiny_lrf',\n", + " 'focalnet_tiny_srf',\n", + " 'focalnet_xlarge_fl3',\n", + " 'focalnet_xlarge_fl4',\n", + " 'gc_efficientnetv2_rw_t',\n", + " 'gcresnet33ts',\n", + " 'gcresnet50t',\n", + " 'gcresnext26ts',\n", + " 'gcresnext50ts',\n", + " 'gcvit_base',\n", + " 'gcvit_small',\n", + " 'gcvit_tiny',\n", + " 'gcvit_xtiny',\n", + " 'gcvit_xxtiny',\n", + " 'gernet_l',\n", + " 'gernet_m',\n", + " 'gernet_s',\n", + " 'ghostnet_050',\n", + " 'ghostnet_100',\n", + " 'ghostnet_130',\n", + " 'ghostnetv2_100',\n", + " 'ghostnetv2_130',\n", + " 'ghostnetv2_160',\n", + " 'ghostnetv3_050',\n", + " 'ghostnetv3_100',\n", + " 'ghostnetv3_130',\n", + " 'ghostnetv3_160',\n", + " 'gmixer_12_224',\n", + " 'gmixer_24_224',\n", + " 'gmlp_b16_224',\n", + " 'gmlp_s16_224',\n", + " 'gmlp_ti16_224',\n", + " 'halo2botnet50ts_256',\n", + " 'halonet26t',\n", + " 'halonet50ts',\n", + " 'halonet_h1',\n", + " 'haloregnetz_b',\n", + " 'hardcorenas_a',\n", + " 'hardcorenas_b',\n", + " 'hardcorenas_c',\n", + " 'hardcorenas_d',\n", + " 'hardcorenas_e',\n", + " 'hardcorenas_f',\n", + " 'hgnet_base',\n", + " 'hgnet_small',\n", + " 'hgnet_tiny',\n", + " 'hgnetv2_b0',\n", + " 'hgnetv2_b1',\n", + " 'hgnetv2_b2',\n", + " 'hgnetv2_b3',\n", + " 'hgnetv2_b4',\n", + " 'hgnetv2_b5',\n", + " 'hgnetv2_b6',\n", + " 'hiera_base_224',\n", + " 'hiera_base_abswin_256',\n", + " 'hiera_base_plus_224',\n", + " 'hiera_huge_224',\n", + " 'hiera_large_224',\n", + " 'hiera_small_224',\n", + " 'hiera_small_abswin_256',\n", + " 'hiera_tiny_224',\n", + " 'hieradet_small',\n", + " 'hrnet_w18',\n", + " 'hrnet_w18_small',\n", + " 'hrnet_w18_small_v2',\n", + " 'hrnet_w18_ssld',\n", + " 'hrnet_w30',\n", + " 'hrnet_w32',\n", + " 'hrnet_w40',\n", + " 'hrnet_w44',\n", + " 'hrnet_w48',\n", + " 'hrnet_w48_ssld',\n", + " 'hrnet_w64',\n", + " 'inception_next_atto',\n", + " 'inception_next_base',\n", + " 'inception_next_small',\n", + " 'inception_next_tiny',\n", + " 'inception_resnet_v2',\n", + " 'inception_v3',\n", + " 'inception_v4',\n", + " 'lambda_resnet26rpt_256',\n", + " 'lambda_resnet26t',\n", + " 'lambda_resnet50ts',\n", + " 'lamhalobotnet50ts_256',\n", + " 'lcnet_035',\n", + " 'lcnet_050',\n", + " 'lcnet_075',\n", + " 'lcnet_100',\n", + " 'lcnet_150',\n", + " 'legacy_senet154',\n", + " 'legacy_seresnet18',\n", + " 'legacy_seresnet34',\n", + " 'legacy_seresnet50',\n", + " 'legacy_seresnet101',\n", + " 
'legacy_seresnet152',\n", + " 'legacy_seresnext26_32x4d',\n", + " 'legacy_seresnext50_32x4d',\n", + " 'legacy_seresnext101_32x4d',\n", + " 'legacy_xception',\n", + " 'levit_128',\n", + " 'levit_128s',\n", + " 'levit_192',\n", + " 'levit_256',\n", + " 'levit_256d',\n", + " 'levit_384',\n", + " 'levit_384_s8',\n", + " 'levit_512',\n", + " 'levit_512_s8',\n", + " 'levit_512d',\n", + " 'levit_conv_128',\n", + " 'levit_conv_128s',\n", + " 'levit_conv_192',\n", + " 'levit_conv_256',\n", + " 'levit_conv_256d',\n", + " 'levit_conv_384',\n", + " 'levit_conv_384_s8',\n", + " 'levit_conv_512',\n", + " 'levit_conv_512_s8',\n", + " 'levit_conv_512d',\n", + " 'mambaout_base',\n", + " 'mambaout_base_plus_rw',\n", + " 'mambaout_base_short_rw',\n", + " 'mambaout_base_tall_rw',\n", + " 'mambaout_base_wide_rw',\n", + " 'mambaout_femto',\n", + " 'mambaout_kobe',\n", + " 'mambaout_small',\n", + " 'mambaout_small_rw',\n", + " 'mambaout_tiny',\n", + " 'maxvit_base_tf_224',\n", + " 'maxvit_base_tf_384',\n", + " 'maxvit_base_tf_512',\n", + " 'maxvit_large_tf_224',\n", + " 'maxvit_large_tf_384',\n", + " 'maxvit_large_tf_512',\n", + " 'maxvit_nano_rw_256',\n", + " 'maxvit_pico_rw_256',\n", + " 'maxvit_rmlp_base_rw_224',\n", + " 'maxvit_rmlp_base_rw_384',\n", + " 'maxvit_rmlp_nano_rw_256',\n", + " 'maxvit_rmlp_pico_rw_256',\n", + " 'maxvit_rmlp_small_rw_224',\n", + " 'maxvit_rmlp_small_rw_256',\n", + " 'maxvit_rmlp_tiny_rw_256',\n", + " 'maxvit_small_tf_224',\n", + " 'maxvit_small_tf_384',\n", + " 'maxvit_small_tf_512',\n", + " 'maxvit_tiny_pm_256',\n", + " 'maxvit_tiny_rw_224',\n", + " 'maxvit_tiny_rw_256',\n", + " 'maxvit_tiny_tf_224',\n", + " 'maxvit_tiny_tf_384',\n", + " 'maxvit_tiny_tf_512',\n", + " 'maxvit_xlarge_tf_224',\n", + " 'maxvit_xlarge_tf_384',\n", + " 'maxvit_xlarge_tf_512',\n", + " 'maxxvit_rmlp_nano_rw_256',\n", + " 'maxxvit_rmlp_small_rw_256',\n", + " 'maxxvit_rmlp_tiny_rw_256',\n", + " 'maxxvitv2_nano_rw_256',\n", + " 'maxxvitv2_rmlp_base_rw_224',\n", + " 'maxxvitv2_rmlp_base_rw_384',\n", + " 'maxxvitv2_rmlp_large_rw_224',\n", + " 'mixer_b16_224',\n", + " 'mixer_b32_224',\n", + " 'mixer_l16_224',\n", + " 'mixer_l32_224',\n", + " 'mixer_s16_224',\n", + " 'mixer_s32_224',\n", + " 'mixnet_l',\n", + " 'mixnet_m',\n", + " 'mixnet_s',\n", + " 'mixnet_xl',\n", + " 'mixnet_xxl',\n", + " 'mnasnet_050',\n", + " 'mnasnet_075',\n", + " 'mnasnet_100',\n", + " 'mnasnet_140',\n", + " 'mnasnet_small',\n", + " 'mobilenet_edgetpu_100',\n", + " 'mobilenet_edgetpu_v2_l',\n", + " 'mobilenet_edgetpu_v2_m',\n", + " 'mobilenet_edgetpu_v2_s',\n", + " 'mobilenet_edgetpu_v2_xs',\n", + " 'mobilenetv1_100',\n", + " 'mobilenetv1_100h',\n", + " 'mobilenetv1_125',\n", + " 'mobilenetv2_035',\n", + " 'mobilenetv2_050',\n", + " 'mobilenetv2_075',\n", + " 'mobilenetv2_100',\n", + " 'mobilenetv2_110d',\n", + " 'mobilenetv2_120d',\n", + " 'mobilenetv2_140',\n", + " 'mobilenetv3_large_075',\n", + " 'mobilenetv3_large_100',\n", + " 'mobilenetv3_large_150d',\n", + " 'mobilenetv3_rw',\n", + " 'mobilenetv3_small_050',\n", + " 'mobilenetv3_small_075',\n", + " 'mobilenetv3_small_100',\n", + " 'mobilenetv4_conv_aa_large',\n", + " 'mobilenetv4_conv_aa_medium',\n", + " 'mobilenetv4_conv_blur_medium',\n", + " 'mobilenetv4_conv_large',\n", + " 'mobilenetv4_conv_medium',\n", + " 'mobilenetv4_conv_small',\n", + " 'mobilenetv4_conv_small_035',\n", + " 'mobilenetv4_conv_small_050',\n", + " 'mobilenetv4_hybrid_large',\n", + " 'mobilenetv4_hybrid_large_075',\n", + " 'mobilenetv4_hybrid_medium',\n", + " 'mobilenetv4_hybrid_medium_075',\n", + " 
'mobilenetv5_300m',\n", + " 'mobilenetv5_300m_enc',\n", + " 'mobilenetv5_base',\n", + " 'mobileone_s0',\n", + " 'mobileone_s1',\n", + " 'mobileone_s2',\n", + " 'mobileone_s3',\n", + " 'mobileone_s4',\n", + " 'mobilevit_s',\n", + " 'mobilevit_xs',\n", + " 'mobilevit_xxs',\n", + " 'mobilevitv2_050',\n", + " 'mobilevitv2_075',\n", + " 'mobilevitv2_100',\n", + " 'mobilevitv2_125',\n", + " 'mobilevitv2_150',\n", + " 'mobilevitv2_175',\n", + " 'mobilevitv2_200',\n", + " 'mvitv2_base',\n", + " 'mvitv2_base_cls',\n", + " 'mvitv2_huge_cls',\n", + " 'mvitv2_large',\n", + " 'mvitv2_large_cls',\n", + " 'mvitv2_small',\n", + " 'mvitv2_small_cls',\n", + " 'mvitv2_tiny',\n", + " 'naflexvit_base_patch16_gap',\n", + " 'naflexvit_base_patch16_map',\n", + " 'naflexvit_base_patch16_par_gap',\n", + " 'naflexvit_base_patch16_parfac_gap',\n", + " 'naflexvit_base_patch16_siglip',\n", + " 'naflexvit_so150m2_patch16_reg1_gap',\n", + " 'naflexvit_so150m2_patch16_reg1_map',\n", + " 'naflexvit_so400m_patch16_siglip',\n", + " 'nasnetalarge',\n", + " 'nest_base',\n", + " 'nest_base_jx',\n", + " 'nest_small',\n", + " 'nest_small_jx',\n", + " 'nest_tiny',\n", + " 'nest_tiny_jx',\n", + " 'nextvit_base',\n", + " 'nextvit_large',\n", + " 'nextvit_small',\n", + " 'nf_ecaresnet26',\n", + " 'nf_ecaresnet50',\n", + " 'nf_ecaresnet101',\n", + " 'nf_regnet_b0',\n", + " 'nf_regnet_b1',\n", + " 'nf_regnet_b2',\n", + " 'nf_regnet_b3',\n", + " 'nf_regnet_b4',\n", + " 'nf_regnet_b5',\n", + " 'nf_resnet26',\n", + " 'nf_resnet50',\n", + " 'nf_resnet101',\n", + " 'nf_seresnet26',\n", + " 'nf_seresnet50',\n", + " 'nf_seresnet101',\n", + " 'nfnet_f0',\n", + " 'nfnet_f1',\n", + " 'nfnet_f2',\n", + " 'nfnet_f3',\n", + " 'nfnet_f4',\n", + " 'nfnet_f5',\n", + " 'nfnet_f6',\n", + " 'nfnet_f7',\n", + " 'nfnet_l0',\n", + " 'pit_b_224',\n", + " 'pit_b_distilled_224',\n", + " 'pit_s_224',\n", + " 'pit_s_distilled_224',\n", + " 'pit_ti_224',\n", + " 'pit_ti_distilled_224',\n", + " 'pit_xs_224',\n", + " 'pit_xs_distilled_224',\n", + " 'pnasnet5large',\n", + " 'poolformer_m36',\n", + " 'poolformer_m48',\n", + " 'poolformer_s12',\n", + " 'poolformer_s24',\n", + " 'poolformer_s36',\n", + " 'poolformerv2_m36',\n", + " 'poolformerv2_m48',\n", + " 'poolformerv2_s12',\n", + " 'poolformerv2_s24',\n", + " 'poolformerv2_s36',\n", + " 'pvt_v2_b0',\n", + " 'pvt_v2_b1',\n", + " 'pvt_v2_b2',\n", + " 'pvt_v2_b2_li',\n", + " 'pvt_v2_b3',\n", + " 'pvt_v2_b4',\n", + " 'pvt_v2_b5',\n", + " 'rdnet_base',\n", + " 'rdnet_large',\n", + " 'rdnet_small',\n", + " 'rdnet_tiny',\n", + " 'regnetv_040',\n", + " 'regnetv_064',\n", + " 'regnetx_002',\n", + " 'regnetx_004',\n", + " 'regnetx_004_tv',\n", + " 'regnetx_006',\n", + " 'regnetx_008',\n", + " 'regnetx_016',\n", + " 'regnetx_032',\n", + " 'regnetx_040',\n", + " 'regnetx_064',\n", + " 'regnetx_080',\n", + " 'regnetx_120',\n", + " 'regnetx_160',\n", + " 'regnetx_320',\n", + " 'regnety_002',\n", + " 'regnety_004',\n", + " 'regnety_006',\n", + " 'regnety_008',\n", + " 'regnety_008_tv',\n", + " 'regnety_016',\n", + " 'regnety_032',\n", + " 'regnety_040',\n", + " 'regnety_040_sgn',\n", + " 'regnety_064',\n", + " 'regnety_080',\n", + " 'regnety_080_tv',\n", + " 'regnety_120',\n", + " 'regnety_160',\n", + " 'regnety_320',\n", + " 'regnety_640',\n", + " 'regnety_1280',\n", + " 'regnety_2560',\n", + " 'regnetz_005',\n", + " 'regnetz_040',\n", + " 'regnetz_040_h',\n", + " 'regnetz_b16',\n", + " 'regnetz_b16_evos',\n", + " 'regnetz_c16',\n", + " 'regnetz_c16_evos',\n", + " 'regnetz_d8',\n", + " 'regnetz_d8_evos',\n", + " 
'regnetz_d32',\n", + " 'regnetz_e8',\n", + " 'repghostnet_050',\n", + " 'repghostnet_058',\n", + " 'repghostnet_080',\n", + " 'repghostnet_100',\n", + " 'repghostnet_111',\n", + " 'repghostnet_130',\n", + " 'repghostnet_150',\n", + " 'repghostnet_200',\n", + " 'repvgg_a0',\n", + " 'repvgg_a1',\n", + " 'repvgg_a2',\n", + " 'repvgg_b0',\n", + " 'repvgg_b1',\n", + " 'repvgg_b1g4',\n", + " 'repvgg_b2',\n", + " 'repvgg_b2g4',\n", + " 'repvgg_b3',\n", + " 'repvgg_b3g4',\n", + " 'repvgg_d2se',\n", + " 'repvit_m0_9',\n", + " 'repvit_m1',\n", + " 'repvit_m1_0',\n", + " 'repvit_m1_1',\n", + " 'repvit_m1_5',\n", + " 'repvit_m2',\n", + " 'repvit_m2_3',\n", + " 'repvit_m3',\n", + " 'res2net50_14w_8s',\n", + " 'res2net50_26w_4s',\n", + " 'res2net50_26w_6s',\n", + " 'res2net50_26w_8s',\n", + " 'res2net50_48w_2s',\n", + " 'res2net50d',\n", + " 'res2net101_26w_4s',\n", + " 'res2net101d',\n", + " 'res2next50',\n", + " 'resmlp_12_224',\n", + " 'resmlp_24_224',\n", + " 'resmlp_36_224',\n", + " 'resmlp_big_24_224',\n", + " 'resnest14d',\n", + " 'resnest26d',\n", + " 'resnest50d',\n", + " 'resnest50d_1s4x24d',\n", + " 'resnest50d_4s2x40d',\n", + " 'resnest101e',\n", + " 'resnest200e',\n", + " 'resnest269e',\n", + " 'resnet10t',\n", + " 'resnet14t',\n", + " 'resnet18',\n", + " 'resnet18d',\n", + " 'resnet26',\n", + " 'resnet26d',\n", + " 'resnet26t',\n", + " 'resnet32ts',\n", + " 'resnet33ts',\n", + " 'resnet34',\n", + " 'resnet34d',\n", + " 'resnet50',\n", + " 'resnet50_clip',\n", + " 'resnet50_clip_gap',\n", + " 'resnet50_gn',\n", + " 'resnet50_mlp',\n", + " 'resnet50c',\n", + " 'resnet50d',\n", + " 'resnet50s',\n", + " 'resnet50t',\n", + " 'resnet50x4_clip',\n", + " 'resnet50x4_clip_gap',\n", + " 'resnet50x16_clip',\n", + " 'resnet50x16_clip_gap',\n", + " 'resnet50x64_clip',\n", + " 'resnet50x64_clip_gap',\n", + " 'resnet51q',\n", + " 'resnet61q',\n", + " 'resnet101',\n", + " 'resnet101_clip',\n", + " 'resnet101_clip_gap',\n", + " 'resnet101c',\n", + " 'resnet101d',\n", + " 'resnet101s',\n", + " 'resnet152',\n", + " 'resnet152c',\n", + " 'resnet152d',\n", + " 'resnet152s',\n", + " 'resnet200',\n", + " 'resnet200d',\n", + " 'resnetaa34d',\n", + " 'resnetaa50',\n", + " 'resnetaa50d',\n", + " 'resnetaa101d',\n", + " 'resnetblur18',\n", + " 'resnetblur50',\n", + " 'resnetblur50d',\n", + " 'resnetblur101d',\n", + " 'resnetrs50',\n", + " 'resnetrs101',\n", + " 'resnetrs152',\n", + " 'resnetrs200',\n", + " 'resnetrs270',\n", + " 'resnetrs350',\n", + " 'resnetrs420',\n", + " 'resnetv2_18',\n", + " 'resnetv2_18d',\n", + " 'resnetv2_34',\n", + " 'resnetv2_34d',\n", + " 'resnetv2_50',\n", + " 'resnetv2_50d',\n", + " 'resnetv2_50d_evos',\n", + " 'resnetv2_50d_frn',\n", + " 'resnetv2_50d_gn',\n", + " 'resnetv2_50t',\n", + " 'resnetv2_50x1_bit',\n", + " 'resnetv2_50x3_bit',\n", + " 'resnetv2_101',\n", + " 'resnetv2_101d',\n", + " 'resnetv2_101x1_bit',\n", + " 'resnetv2_101x3_bit',\n", + " 'resnetv2_152',\n", + " 'resnetv2_152d',\n", + " 'resnetv2_152x2_bit',\n", + " 'resnetv2_152x4_bit',\n", + " 'resnext26ts',\n", + " 'resnext50_32x4d',\n", + " 'resnext50d_32x4d',\n", + " 'resnext101_32x4d',\n", + " 'resnext101_32x8d',\n", + " 'resnext101_32x16d',\n", + " 'resnext101_32x32d',\n", + " 'resnext101_64x4d',\n", + " 'rexnet_100',\n", + " 'rexnet_130',\n", + " 'rexnet_150',\n", + " 'rexnet_200',\n", + " 'rexnet_300',\n", + " 'rexnetr_100',\n", + " 'rexnetr_130',\n", + " 'rexnetr_150',\n", + " 'rexnetr_200',\n", + " 'rexnetr_300',\n", + " 'sam2_hiera_base_plus',\n", + " 'sam2_hiera_large',\n", + " 'sam2_hiera_small',\n", + " 
'sam2_hiera_tiny',\n", + " 'samvit_base_patch16',\n", + " 'samvit_base_patch16_224',\n", + " 'samvit_huge_patch16',\n", + " 'samvit_large_patch16',\n", + " 'sebotnet33ts_256',\n", + " 'sedarknet21',\n", + " 'sehalonet33ts',\n", + " 'selecsls42',\n", + " 'selecsls42b',\n", + " 'selecsls60',\n", + " 'selecsls60b',\n", + " 'selecsls84',\n", + " 'semnasnet_050',\n", + " 'semnasnet_075',\n", + " 'semnasnet_100',\n", + " 'semnasnet_140',\n", + " 'senet154',\n", + " 'sequencer2d_l',\n", + " 'sequencer2d_m',\n", + " 'sequencer2d_s',\n", + " 'seresnet18',\n", + " 'seresnet33ts',\n", + " 'seresnet34',\n", + " 'seresnet50',\n", + " 'seresnet50t',\n", + " 'seresnet101',\n", + " 'seresnet152',\n", + " 'seresnet152d',\n", + " 'seresnet200d',\n", + " 'seresnet269d',\n", + " 'seresnetaa50d',\n", + " 'seresnext26d_32x4d',\n", + " 'seresnext26t_32x4d',\n", + " 'seresnext26ts',\n", + " 'seresnext50_32x4d',\n", + " 'seresnext101_32x4d',\n", + " 'seresnext101_32x8d',\n", + " 'seresnext101_64x4d',\n", + " 'seresnext101d_32x8d',\n", + " 'seresnextaa101d_32x8d',\n", + " 'seresnextaa201d_32x8d',\n", + " 'shvit_s1',\n", + " 'shvit_s2',\n", + " 'shvit_s3',\n", + " 'shvit_s4',\n", + " 'skresnet18',\n", + " 'skresnet34',\n", + " 'skresnet50',\n", + " 'skresnet50d',\n", + " 'skresnext50_32x4d',\n", + " 'spnasnet_100',\n", + " 'starnet_s1',\n", + " 'starnet_s2',\n", + " 'starnet_s3',\n", + " 'starnet_s4',\n", + " 'starnet_s050',\n", + " 'starnet_s100',\n", + " 'starnet_s150',\n", + " 'swiftformer_l1',\n", + " 'swiftformer_l3',\n", + " 'swiftformer_s',\n", + " 'swiftformer_xs',\n", + " 'swin_base_patch4_window7_224',\n", + " 'swin_base_patch4_window12_384',\n", + " 'swin_large_patch4_window7_224',\n", + " 'swin_large_patch4_window12_384',\n", + " 'swin_s3_base_224',\n", + " 'swin_s3_small_224',\n", + " 'swin_s3_tiny_224',\n", + " 'swin_small_patch4_window7_224',\n", + " 'swin_tiny_patch4_window7_224',\n", + " 'swinv2_base_window8_256',\n", + " 'swinv2_base_window12_192',\n", + " 'swinv2_base_window12to16_192to256',\n", + " 'swinv2_base_window12to24_192to384',\n", + " 'swinv2_base_window16_256',\n", + " 'swinv2_cr_base_224',\n", + " 'swinv2_cr_base_384',\n", + " 'swinv2_cr_base_ns_224',\n", + " 'swinv2_cr_giant_224',\n", + " 'swinv2_cr_giant_384',\n", + " 'swinv2_cr_huge_224',\n", + " 'swinv2_cr_huge_384',\n", + " 'swinv2_cr_large_224',\n", + " 'swinv2_cr_large_384',\n", + " 'swinv2_cr_small_224',\n", + " 'swinv2_cr_small_384',\n", + " 'swinv2_cr_small_ns_224',\n", + " 'swinv2_cr_small_ns_256',\n", + " 'swinv2_cr_tiny_224',\n", + " 'swinv2_cr_tiny_384',\n", + " 'swinv2_cr_tiny_ns_224',\n", + " 'swinv2_large_window12_192',\n", + " 'swinv2_large_window12to16_192to256',\n", + " 'swinv2_large_window12to24_192to384',\n", + " 'swinv2_small_window8_256',\n", + " 'swinv2_small_window16_256',\n", + " 'swinv2_tiny_window8_256',\n", + " 'swinv2_tiny_window16_256',\n", + " 'test_byobnet',\n", + " 'test_convnext',\n", + " 'test_convnext2',\n", + " 'test_convnext3',\n", + " 'test_efficientnet',\n", + " 'test_efficientnet_evos',\n", + " 'test_efficientnet_gn',\n", + " 'test_efficientnet_ln',\n", + " 'test_mambaout',\n", + " 'test_nfnet',\n", + " 'test_resnet',\n", + " 'test_vit',\n", + " 'test_vit2',\n", + " 'test_vit3',\n", + " 'test_vit4',\n", + " 'tf_efficientnet_b0',\n", + " 'tf_efficientnet_b1',\n", + " 'tf_efficientnet_b2',\n", + " 'tf_efficientnet_b3',\n", + " 'tf_efficientnet_b4',\n", + " 'tf_efficientnet_b5',\n", + " 'tf_efficientnet_b6',\n", + " 'tf_efficientnet_b7',\n", + " 'tf_efficientnet_b8',\n", + " 
'tf_efficientnet_cc_b0_4e',\n",
+    " 'tf_efficientnet_cc_b0_8e',\n",
+    " 'tf_efficientnet_cc_b1_8e',\n",
+    " 'tf_efficientnet_el',\n",
+    " 'tf_efficientnet_em',\n",
+    " 'tf_efficientnet_es',\n",
+    " 'tf_efficientnet_l2',\n",
+    " 'tf_efficientnet_lite0',\n",
+    " 'tf_efficientnet_lite1',\n",
+    " 'tf_efficientnet_lite2',\n",
+    " 'tf_efficientnet_lite3',\n",
+    " 'tf_efficientnet_lite4',\n",
+    " 'tf_efficientnetv2_b0',\n",
+    " 'tf_efficientnetv2_b1',\n",
+    " 'tf_efficientnetv2_b2',\n",
+    " 'tf_efficientnetv2_b3',\n",
+    " 'tf_efficientnetv2_l',\n",
+    " 'tf_efficientnetv2_m',\n",
+    " 'tf_efficientnetv2_s',\n",
+    " 'tf_efficientnetv2_xl',\n",
+    " 'tf_mixnet_l',\n",
+    " 'tf_mixnet_m',\n",
+    " 'tf_mixnet_s',\n",
+    " 'tf_mobilenetv3_large_075',\n",
+    " 'tf_mobilenetv3_large_100',\n",
+    " 'tf_mobilenetv3_large_minimal_100',\n",
+    " 'tf_mobilenetv3_small_075',\n",
+    " 'tf_mobilenetv3_small_100',\n",
+    " 'tf_mobilenetv3_small_minimal_100',\n",
+    " 'tiny_vit_5m_224',\n",
+    " 'tiny_vit_11m_224',\n",
+    " 'tiny_vit_21m_224',\n",
+    " 'tiny_vit_21m_384',\n",
+    " 'tiny_vit_21m_512',\n",
+    " 'tinynet_a',\n",
+    " 'tinynet_b',\n",
+    " 'tinynet_c',\n",
+    " ...]"
+     ]
+    },
+    "execution_count": 64,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
+  "source": [
+   "timm.list_models()"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "e965b948",
+  "metadata": {},
+  "source": [
+   "# Test Inference"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 65,
+  "id": "d72de444",
+  "metadata": {},
+  "outputs": [
+   [truncated notebook output: an ipywidgets progress bar ("Resolving data files: 0/31537"), followed by
+   dozens of per-batch stdout prints pairing a 5x6 logit tensor with a scalar tensor on cuda:0 (apparently
+   the batch loss, ranging roughly 2.1-6.7), plus a numpy DeprecationWarning on stderr pointing at line 10
+   of the cell, results.append(np.array(torch.argmax(out["logits"], dim=1).cpu())); the diff is cut off
+   mid-output]
17.7513],\n", + " [-14.3584, -13.3932, -14.8911, -9.2751, -18.0065, 9.8739]],\n", + " device='cuda:0', grad_fn=) tensor(4.6932, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.3330, -9.8182, -18.5348, -11.8453, -18.9183, 10.8044],\n", + " [-15.1557, -13.9038, -15.7074, -10.8582, -18.2332, 11.8055],\n", + " [-18.6996, -14.5197, -20.9992, -12.6252, -20.1872, 13.1674],\n", + " [-22.8469, -19.7574, -24.8654, -15.2763, -26.8487, 17.1085],\n", + " [-10.4986, -8.8861, -10.7186, -5.3894, -13.6085, 6.3734]],\n", + " device='cuda:0', grad_fn=) tensor(4.7266, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.6296, -15.6283, -20.6069, -12.6860, -23.3083, 14.4273],\n", + " [-13.3279, -9.0744, -13.8022, -8.2591, -16.8325, 8.2038],\n", + " [-13.4357, -9.8973, -15.1503, -10.1377, -16.4957, 10.2683],\n", + " [-10.2582, -8.3794, -14.0234, -7.2057, -14.7144, 6.7166],\n", + " [-13.6460, -8.2001, -14.9831, -9.2263, -17.0637, 9.4848]],\n", + " device='cuda:0', grad_fn=) tensor(3.9467, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-11.6963, -9.2469, -12.6617, -8.1706, -14.4895, 9.3311],\n", + " [-20.3952, -18.3361, -21.6297, -13.9980, -25.2609, 15.2594],\n", + " [-22.0856, -18.9060, -20.6000, -14.3744, -24.6768, 15.1909],\n", + " [-15.2122, -10.5237, -15.9923, -11.6820, -18.4171, 10.4842],\n", + " [-13.1434, -11.2104, -9.8275, -10.0713, -15.1064, 8.3303]],\n", + " device='cuda:0', grad_fn=) tensor(4.7043, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-13.2211, -9.5423, -13.1478, -8.8343, -15.3084, 8.7082],\n", + " [-16.7319, -12.6597, -17.8649, -11.3318, -20.8555, 11.3941],\n", + " [-17.7222, -10.5871, -19.6106, -12.1121, -19.2650, 12.1031],\n", + " [-19.6047, -18.7017, -20.6850, -10.4466, -24.2143, 12.4366],\n", + " [-13.8230, -12.3804, -15.4939, 6.6683, -17.3757, -2.6268]],\n", + " device='cuda:0', grad_fn=) tensor(4.4162, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.1193, -14.8870, -17.5374, -11.4638, -20.5449, 13.0229],\n", + " [-20.6139, -17.4176, -19.7047, -10.7328, -22.7407, 13.2464],\n", + " [-20.6287, -18.4637, -22.4636, -15.8730, -25.1109, 17.0365],\n", + " [-15.1208, -13.3592, -16.8950, -9.7724, -18.1916, 10.8109],\n", + " [-21.5424, -18.0085, -20.5713, -12.5505, -23.4866, 15.6004]],\n", + " device='cuda:0', grad_fn=) tensor(5.5247, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.4129, -18.2680, -19.5686, -12.8061, -24.2072, 14.3643],\n", + " [-24.4469, -17.0240, -26.6052, -14.7471, -27.1959, 15.1867],\n", + " [-20.6170, -19.4698, -20.2167, -12.6614, -23.4191, 14.6381],\n", + " [-18.8323, -16.9009, -22.7492, -14.6131, -22.3041, 15.3515],\n", + " [-19.2088, -17.0492, -21.2708, -12.2217, -23.2812, 14.3415]],\n", + " device='cuda:0', grad_fn=) tensor(5.9133, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.8148, -17.4347, -20.4716, -12.1514, -23.6210, 15.4582],\n", + " [-19.1338, -15.6345, -20.4077, -11.0689, -22.6674, 13.5178],\n", + " [-18.6162, -16.1241, -22.1729, -13.1607, -23.0949, 13.3637],\n", + " [-15.0940, -12.3485, -15.2322, -9.0852, -16.6256, 8.9301],\n", + " [-20.9731, -17.2684, -19.6739, -10.5838, -21.9027, 12.4065]],\n", + " device='cuda:0', grad_fn=) tensor(5.2436, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.5214, -15.9404, -20.7522, -11.6856, -22.1833, 12.5102],\n", + " [-14.5146, -12.0837, -12.8237, -12.1591, -16.6726, 10.1114],\n", + " [-16.6550, -16.0756, -16.9363, -12.7888, -20.6022, 13.5074],\n", + " [-13.6076, -9.4316, -15.9310, -9.9720, -16.3886, 11.5593],\n", + " [-20.2257, -16.8778, -22.4005, -13.2786, -21.4215, 
14.5956]],\n", + " device='cuda:0', grad_fn=) tensor(4.8603, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.1366, -17.3678, -22.1653, -13.3608, -22.3511, 15.2668],\n", + " [-11.4665, -11.4157, -12.7001, -7.7738, -12.1872, 7.1619],\n", + " [-18.3742, -13.8976, -18.9590, -9.7601, -20.8011, 10.7781],\n", + " [-13.9690, -6.9894, -13.6376, -7.2607, -13.0903, 5.0189],\n", + " [-27.2542, -21.2988, -29.1481, -18.2303, -31.1168, 19.5048]],\n", + " device='cuda:0', grad_fn=) tensor(4.9647, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.7593, -19.5813, -23.7843, -15.7244, -24.8389, 17.3628],\n", + " [-24.4455, -20.9051, -25.2435, -15.8720, -28.3832, 16.1786],\n", + " [-19.1644, -14.5254, -20.0183, -13.1755, -24.6596, 14.1023],\n", + " [-16.2799, -10.8130, -17.4141, -8.0549, -18.6652, 8.5796],\n", + " [-20.8250, -17.1108, -22.6241, -13.8311, -26.5337, 15.1030]],\n", + " device='cuda:0', grad_fn=) tensor(5.8267, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.8580, -19.5403, -21.5024, -11.0178, -24.2921, 13.4842],\n", + " [-16.8547, -13.1214, -18.1288, -10.4439, -20.0097, 11.6236],\n", + " [-18.8804, -16.6998, -20.2242, -12.0481, -23.1131, 13.1611],\n", + " [-13.9925, -12.4996, -15.5666, -11.0671, -17.2274, 10.1886],\n", + " [-13.9824, -12.9184, -15.4214, -7.4420, -16.6134, 8.4947]],\n", + " device='cuda:0', grad_fn=) tensor(4.6840, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.2970, -14.2517, -17.7940, -11.0901, -20.4699, 12.0047],\n", + " [-22.3773, -17.9124, -24.3739, -13.9360, -27.1090, 16.0533],\n", + " [-16.8401, -15.5732, -19.3255, -12.9068, -22.7579, 12.8500],\n", + " [-16.9204, -15.4857, -19.5650, -11.0775, -21.3156, 12.2616],\n", + " [-18.7898, -11.4660, -22.6002, -13.2158, -23.7073, 12.9592]],\n", + " device='cuda:0', grad_fn=) tensor(5.2451, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.3884, -13.5856, -17.9079, -9.7859, -17.7833, 11.7924],\n", + " [-19.8772, -14.3560, -22.5769, -11.7164, -22.1028, 12.6778],\n", + " [-16.5265, -15.7999, -19.0882, -11.4491, -20.8517, 13.6349],\n", + " [-16.6222, -15.9760, -16.6970, -11.4053, -20.2419, 12.7633],\n", + " [-19.3569, -12.0447, -22.2062, -14.7031, -22.1541, 13.1451]],\n", + " device='cuda:0', grad_fn=) tensor(5.0928, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-32.9812, -27.4990, -33.0810, -19.9676, -37.1122, 22.7010],\n", + " [-25.1412, -22.5083, -26.1246, -17.9399, -30.2328, 19.0551],\n", + " [-18.1044, -17.4468, -16.1489, -12.2034, -21.7271, 13.3791],\n", + " [-22.0484, -18.7400, -23.6264, -15.1139, -25.0772, 16.2556],\n", + " [-21.4765, -19.0986, -21.7936, -14.6659, -25.0793, 16.4497]],\n", + " device='cuda:0', grad_fn=) tensor(6.9197, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.9703, -12.5283, -15.9123, -10.3607, -16.9043, 11.8670],\n", + " [-14.7851, -11.3858, -14.6984, -8.0111, -17.8239, 9.1729],\n", + " [-23.3414, -16.4799, -23.8550, -14.2612, -25.1833, 15.5180],\n", + " [-11.8542, -14.9106, -17.7117, -10.1824, -19.5678, 9.4499],\n", + " [-16.5264, -11.4633, -14.9103, -13.0815, -16.8892, 11.1628]],\n", + " device='cuda:0', grad_fn=) tensor(4.6550, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-24.5753, -19.0693, -25.6013, -16.3616, -29.0606, 16.9855],\n", + " [-18.4498, -16.5574, -19.0798, -12.9496, -22.0195, 15.4385],\n", + " [-17.8592, -15.2445, -20.2465, -10.8744, -21.1017, 12.4335],\n", + " [-16.3447, -16.2394, -20.7919, -11.0883, -22.6917, 11.9366],\n", + " [-20.8857, -16.6305, -22.3670, -13.4147, -25.3281, 15.1691]],\n", + " device='cuda:0', grad_fn=) tensor(5.6693, 
device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.3107, -17.2678, -24.4074, -15.4727, -27.0266, 16.8336],\n", + " [-18.4084, -14.8985, -19.2755, -12.0711, -21.3660, 14.0156],\n", + " [-18.7173, -18.2618, -20.5120, -13.1939, -24.9243, 13.7050],\n", + " [-18.5922, -13.9030, -18.8118, -10.7443, -20.9841, 11.6499],\n", + " [-18.7843, -15.1183, -20.3662, -9.4882, -22.0089, 11.7991]],\n", + " device='cuda:0', grad_fn=) tensor(5.4939, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-14.5255, -12.5962, -17.4335, -11.1467, -18.9626, 11.4310],\n", + " [-18.3658, -15.5181, -19.8912, -12.0178, -22.4920, 14.0607],\n", + " [-17.5977, -14.6222, -19.6648, -11.2494, -21.4487, 12.7164],\n", + " [-10.3094, -10.8707, -10.0709, -7.6031, -13.0766, 7.3406],\n", + " [-18.6899, -15.6591, -20.4259, -11.2657, -22.5217, 13.6514]],\n", + " device='cuda:0', grad_fn=) tensor(4.6230, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.0740, -17.1642, -19.3914, -11.9442, -23.4112, 13.1909],\n", + " [-17.8197, -14.5614, -19.1308, -11.5887, -20.4041, 12.1787],\n", + " [-19.0360, -13.7985, -18.8315, -12.2055, -22.0558, 12.9905],\n", + " [-16.9382, -12.0843, -19.4143, -9.3930, -20.9384, 10.6815],\n", + " [-17.4838, -15.3408, -20.2609, -9.5611, -21.3227, 11.1091]],\n", + " device='cuda:0', grad_fn=) tensor(4.9834, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-10.6968, -11.7862, -13.2379, -6.2418, -14.6098, 7.1439],\n", + " [-16.5674, -15.7555, -18.6584, -8.0537, -20.0214, 11.4463],\n", + " [-12.7221, -9.2729, -17.5542, -10.4586, -17.9921, 6.7436],\n", + " [-21.6593, -19.3343, -23.6112, -12.6774, -27.3476, 14.4801],\n", + " [-21.3644, -17.6912, -22.5517, -14.6001, -26.5790, 15.4281]],\n", + " device='cuda:0', grad_fn=) tensor(4.6085, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.8990, -17.7780, -21.0489, -13.2314, -25.6239, 14.9066],\n", + " [-19.5994, -18.4150, -22.4583, -11.8884, -24.9489, 14.8532],\n", + " [-19.2462, -17.4901, -21.9835, -10.9877, -23.2811, 13.5844],\n", + " [-15.5826, -15.2329, -19.3137, -9.8560, -20.8139, 10.5707],\n", + " [-14.5252, -10.7643, -17.0969, -8.3940, -17.8676, 9.7440]],\n", + " device='cuda:0', grad_fn=) tensor(5.1171, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-23.9719, -20.1234, -28.4602, -17.6994, -28.3996, 17.4066],\n", + " [-18.7809, -14.9155, -19.8409, -11.6401, -21.6021, 13.9610],\n", + " [-18.5305, -14.9233, -19.0456, -13.4152, -20.3604, 12.9150],\n", + " [-20.7417, -18.2628, -21.6057, -15.5778, -23.2930, 16.6837],\n", + " [-17.9475, -14.1712, -19.6554, -13.3653, -21.2943, 13.8216]],\n", + " device='cuda:0', grad_fn=) tensor(5.8254, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.7057, -18.9001, -19.3759, -13.4884, -24.9231, 14.8078],\n", + " [-16.1870, -12.3184, -15.6544, -11.0296, -18.6483, 10.9616],\n", + " [-15.8026, -11.9614, -16.7288, -12.8908, -19.2097, 11.4760],\n", + " [ -9.8667, -12.8459, -12.7908, -8.4372, -15.7941, 8.5906],\n", + " [-18.6315, -15.2141, -21.0891, -13.1284, -23.1672, 13.2186]],\n", + " device='cuda:0', grad_fn=) tensor(4.6416, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-12.3170, -12.1141, -13.8850, -7.9731, -17.3659, 8.5570],\n", + " [-11.4212, -10.2209, -14.5908, -7.4196, -16.2881, 7.7112],\n", + " [ -9.7739, -12.8511, -12.4381, -7.6655, -14.4346, 7.7773],\n", + " [-17.9632, -14.6031, -20.9872, -13.9909, -22.6986, 13.1252],\n", + " [-20.8732, -18.3798, -24.9761, -16.3162, -26.3668, 16.9583]],\n", + " device='cuda:0', grad_fn=) tensor(4.2160, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.7734, -18.7654, 
-22.4377, -14.1063, -25.6670, 16.9886],\n", + " [-14.9720, -14.9534, -19.4248, -14.3606, -19.9433, 13.3392],\n", + " [-21.3588, -18.4297, -23.3515, -14.0178, -25.4403, 16.4733],\n", + " [-16.7734, -17.7400, -21.6076, -14.7985, -24.4414, 13.4196],\n", + " [-15.9589, -13.4808, -16.9370, -10.4864, -19.4568, 11.5403]],\n", + " device='cuda:0', grad_fn=) tensor(5.4532, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.3548, -16.7785, -18.4801, -12.0774, -22.7151, 13.5736],\n", + " [-18.7703, -19.9252, -19.5998, -12.3849, -25.6287, 13.0342],\n", + " [-15.2746, -14.4263, -16.7372, -11.9905, -18.9224, 12.9142],\n", + " [-17.7992, -15.0942, -18.7344, -11.5380, -22.3793, 13.2218],\n", + " [-18.4698, -15.9423, -20.4555, -12.8131, -22.0360, 14.7222]],\n", + " device='cuda:0', grad_fn=) tensor(5.2045, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.9748, -13.7058, -19.0213, -11.8479, -21.9417, 13.5272],\n", + " [-14.9068, -15.0676, -17.9662, -12.5415, -20.2507, 11.7762],\n", + " [-20.3708, -17.9092, -20.5850, -14.4878, -24.9056, 13.6956],\n", + " [-22.7280, -18.8601, -22.8030, -14.1650, -25.4472, 16.1378],\n", + " [-17.7265, -16.3266, -17.4803, -13.3785, -21.0708, 13.4891]],\n", + " device='cuda:0', grad_fn=) tensor(5.4111, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.8754, -15.7071, -17.7449, -13.0361, -19.9818, 13.8916],\n", + " [-21.3584, -14.6113, -25.3200, -15.9540, -26.0373, 15.1054],\n", + " [-15.2743, -13.0577, -15.4156, -10.3202, -17.6266, 11.2770],\n", + " [-14.2055, -14.5312, -18.7720, -13.2209, -20.5926, 11.5475],\n", + " [-22.4946, -20.2387, -24.5248, -14.6311, -26.9361, 17.1001]],\n", + " device='cuda:0', grad_fn=) tensor(5.3043, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.5130, -13.8763, -15.8584, -10.6193, -18.6628, 11.4881],\n", + " [-20.3973, -16.6515, -20.9321, -12.2857, -22.1069, 13.7251],\n", + " [-20.5365, -19.3787, -22.4803, -10.7939, -24.9822, 14.0701],\n", + " [-20.2452, -16.8394, -19.4366, -13.7091, -23.2141, 14.7916],\n", + " [-19.1439, -17.5626, -20.2701, -14.3621, -24.5424, 15.2705]],\n", + " device='cuda:0', grad_fn=) tensor(5.5060, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.6656, -15.7138, -18.2888, -10.9559, -22.1569, 13.3772],\n", + " [-17.7336, -15.7905, -16.7400, -10.3683, -21.4636, 12.8227],\n", + " [-19.7585, -18.6270, -20.1307, -12.2257, -24.3634, 15.1827],\n", + " [-19.4242, -17.5527, -20.4448, -14.9366, -23.4163, 15.6954],\n", + " [-20.2690, -17.9048, -19.3549, -14.4854, -22.0681, 14.6792]],\n", + " device='cuda:0', grad_fn=) tensor(5.5869, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-28.1177, -20.9715, -27.2392, -17.2305, -31.3013, 17.7767],\n", + " [-19.1474, -17.2928, -19.7323, -11.9047, -22.8305, 12.9609],\n", + " [-17.2956, -15.3368, -17.8013, -11.9221, -20.5031, 14.0885],\n", + " [-12.6312, -12.2924, -13.4635, -10.5112, -16.6389, 10.2560],\n", + " [-19.0983, -17.7811, -21.8155, -11.8038, -22.7976, 12.1057]],\n", + " device='cuda:0', grad_fn=) tensor(5.4493, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.1600, -13.0184, -16.4458, -11.0808, -18.5597, 11.1225],\n", + " [-19.3812, -16.1166, -18.3470, -13.0626, -21.5304, 14.6689],\n", + " [-21.5873, -17.4331, -20.2584, -13.1244, -24.7632, 14.0848],\n", + " [-23.7779, -18.7419, -24.6766, -15.7794, -27.1349, 15.7003],\n", + " [-23.2882, -18.8450, -23.5957, -13.4268, -25.4233, 15.2923]],\n", + " device='cuda:0', grad_fn=) tensor(5.8354, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.9866, -15.4878, -20.1301, -11.2216, -21.4158, 11.9829],\n", + " 
[-15.9552, -14.6556, -16.4333, -10.1787, -18.7979, 10.3496],\n", + " [-22.7104, -20.2965, -23.0794, -13.0905, -25.6637, 15.5291],\n", + " [-23.5403, -20.4541, -25.9342, -13.8349, -28.3207, 16.9053],\n", + " [-20.9428, -17.7623, -23.1671, -12.4682, -24.1768, 14.2896]],\n", + " device='cuda:0', grad_fn=) tensor(5.6397, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.0949, -16.7221, -20.5008, -12.0557, -23.2993, 14.4390],\n", + " [-21.4803, -18.2019, -22.1185, -15.1478, -26.1642, 15.3335],\n", + " [-21.4243, -20.2971, -23.3980, -16.1245, -24.5950, 17.2290],\n", + " [-23.4091, -22.6142, -25.9959, -16.4560, -27.4431, 18.6262],\n", + " [-22.3768, -21.1278, -23.1824, -14.7416, -27.8023, 16.7324]],\n", + " device='cuda:0', grad_fn=) tensor(6.3382, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.3863, -13.5600, -19.9461, -11.2188, -20.9589, 11.4459],\n", + " [-14.6725, -11.3649, -16.0441, -10.0943, -19.2804, 10.0089],\n", + " [-16.7444, -13.4433, -17.6785, -8.1769, -20.1085, 10.2394],\n", + " [ -9.7199, -5.7652, -10.0145, -6.7054, -9.8233, 5.9927],\n", + " [-19.9933, -16.5628, -20.9082, -9.7896, -22.6419, 11.9820]],\n", + " device='cuda:0', grad_fn=) tensor(4.2731, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.0861, -11.9654, -17.5644, -10.2035, -18.9767, 11.8931],\n", + " [-15.9785, -13.4638, -17.9364, -9.6540, -20.1270, 10.0593],\n", + " [-20.1442, -17.3396, -21.4897, -13.7703, -23.7502, 13.8689],\n", + " [-11.9642, -12.9455, -14.1612, -8.1744, -16.9217, 9.6575],\n", + " [-21.4610, -17.9839, -23.1693, -12.4272, -24.6516, 14.0490]],\n", + " device='cuda:0', grad_fn=) tensor(4.8721, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.1201, -20.7091, -23.0904, -13.9579, -26.0676, 15.6776],\n", + " [-14.5174, -14.5780, -16.3154, -10.3883, -17.2480, 10.0580],\n", + " [-14.4431, -11.3692, -16.6122, -13.5323, -19.6046, 11.2442],\n", + " [-27.7304, -24.5485, -28.6611, -19.3034, -31.4497, 20.7093],\n", + " [-22.7288, -18.1551, -24.3351, -16.3573, -26.0221, 17.1175]],\n", + " device='cuda:0', grad_fn=) tensor(5.8115, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.4127, -15.5328, -17.9819, -8.9730, -18.7125, 10.6280],\n", + " [-23.7108, -19.6902, -23.8317, -16.1002, -26.8017, 18.3754],\n", + " [-16.6169, -14.0449, -18.5231, -10.0803, -20.8324, 12.0207],\n", + " [-16.1439, -14.9718, -17.6866, -8.8520, -19.4620, 10.4882],\n", + " [-24.1040, -20.5658, -26.6451, -16.7603, -26.7051, 17.2574]],\n", + " device='cuda:0', grad_fn=) tensor(5.5253, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.2382, -14.3147, -18.9593, -12.2502, -19.1783, 12.1915],\n", + " [-14.6757, -12.5675, -16.1349, -7.8402, -16.6600, 8.6357],\n", + " [-19.6062, -15.8615, -20.6970, -9.9591, -24.0498, 10.3431],\n", + " [-14.4842, -13.9449, -16.3620, -4.8229, -17.7464, 6.8868],\n", + " [-16.3014, -13.0761, -17.5243, -10.0741, -17.6989, 11.8242]],\n", + " device='cuda:0', grad_fn=) tensor(4.4399, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.6400, -15.7390, -19.3633, -11.4395, -20.4425, 12.6652],\n", + " [-11.4393, -12.9374, -14.6822, -7.5145, -14.8698, 9.0114],\n", + " [-16.9832, -15.1343, -19.4977, -12.0574, -21.4640, 13.4612],\n", + " [-18.1148, -13.0337, -20.8080, -11.2504, -21.6554, 12.2917],\n", + " [-15.0605, -16.2581, -18.6113, -7.4220, -20.1719, 8.6069]],\n", + " device='cuda:0', grad_fn=) tensor(4.5092, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.1585, -12.7229, -17.1500, -9.8196, -20.7888, 10.2071],\n", + " [-16.0596, -11.8572, -16.5405, -10.5472, -19.1893, 11.8997],\n", + " 
[-16.8690, -15.1268, -17.2011, -10.8024, -20.6433, 12.2531],\n", + " [-18.7775, -17.6282, -19.9257, -12.5080, -21.3468, 14.5402],\n", + " [-14.9997, -14.0050, -15.8017, -11.3393, -19.5409, 10.9783]],\n", + " device='cuda:0', grad_fn=) tensor(4.7581, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-10.2756, -8.0491, -11.0042, -4.5234, -10.3225, 4.0230],\n", + " [-16.0398, -11.1164, -15.1176, -10.0344, -18.6164, 10.1628],\n", + " [-12.5132, -10.8211, -14.9827, -9.1742, -17.1821, 9.9145],\n", + " [-18.4439, -12.7465, -21.5300, -11.5545, -23.4562, 11.8798],\n", + " [-21.1030, -19.3982, -19.9244, -13.8952, -24.4015, 15.4825]],\n", + " device='cuda:0', grad_fn=) tensor(4.3289, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.6398, -17.9340, -21.5053, -14.1912, -25.2801, 14.3820],\n", + " [-16.2111, -10.9896, -17.2865, -10.2637, -19.2123, 10.6698],\n", + " [-18.0670, -15.7895, -19.0299, -10.2522, -21.1252, 12.0592],\n", + " [-15.1078, -12.0009, -16.4496, -9.8840, -18.4530, 11.3191],\n", + " [-12.8361, -9.4217, -13.4028, -7.9373, -17.3825, 8.7229]],\n", + " device='cuda:0', grad_fn=) tensor(4.7339, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-14.4710, -13.7387, -12.5091, -9.9019, -17.1015, 10.5661],\n", + " [-14.6864, -14.2809, -10.7986, -10.0192, -19.4236, 8.6897],\n", + " [-21.4247, -16.7861, -23.6037, -14.4491, -25.0436, 15.8453],\n", + " [-18.1230, -15.4493, -16.6629, -12.0134, -21.2081, 12.0806],\n", + " [-18.1554, -15.5179, -19.0351, -12.4974, -20.5325, 12.8499]],\n", + " device='cuda:0', grad_fn=) tensor(4.8964, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-13.7177, -13.8416, -15.2486, -9.6244, -18.0467, 10.3239],\n", + " [-24.1385, -18.0729, -26.3439, -15.7565, -28.2851, 17.1401],\n", + " [-17.1934, -14.4554, -17.3910, -11.3049, -20.5639, 11.8824],\n", + " [-17.3398, -14.8960, -17.9027, -8.2123, -20.2282, 9.2361],\n", + " [-19.1341, -14.8859, -21.3039, -11.6237, -20.6105, 13.4214]],\n", + " device='cuda:0', grad_fn=) tensor(5.1176, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.9768, -18.6529, -21.9484, -14.2502, -24.1503, 16.5473],\n", + " [-23.8706, -19.8372, -24.6863, -16.7051, -25.5844, 17.6105],\n", + " [-17.2622, -14.3734, -17.6110, -9.3845, -15.7502, 8.8129],\n", + " [-13.8973, -12.4226, -15.6001, -11.5369, -15.7504, 11.3192],\n", + " [-14.9060, -14.4255, -19.7606, -13.1144, -20.6059, 12.2337]],\n", + " device='cuda:0', grad_fn=) tensor(5.2479, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-12.6985, -11.5600, -11.3345, -7.3910, -14.6318, 7.7115],\n", + " [-13.4991, -11.5705, -14.6853, -6.5622, -14.5839, 10.1141],\n", + " [ -7.4961, -6.3502, -6.6151, -3.9471, -5.8693, 3.8525],\n", + " [-17.8359, -15.9585, -16.5560, -10.1990, -18.1437, 10.4733],\n", + " [ -8.6907, -9.6264, -9.6112, -1.7396, -10.4617, 3.7027]],\n", + " device='cuda:0', grad_fn=) tensor(3.2103, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-10.6845, -9.8545, -10.1369, -4.1881, -11.8488, 7.3928],\n", + " [-16.7131, -14.5159, -15.4915, -10.1817, -18.0355, 10.3292],\n", + " [-13.8805, -10.1855, -11.8923, -8.8698, -15.4124, 8.3858],\n", + " [-12.4117, -9.6254, -12.4239, -6.2932, -13.4157, 7.7600],\n", + " [-19.5006, -17.0036, -18.6615, -13.4596, -21.4147, 14.0077]],\n", + " device='cuda:0', grad_fn=) tensor(4.0361, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.3934, -12.1062, -16.5788, -7.5249, -17.8767, 8.9113],\n", + " [-17.2418, -14.8963, -21.0337, -12.5611, -22.3435, 12.9004],\n", + " [-15.2857, -13.2404, -17.3960, -9.4093, -18.8769, 10.5207],\n", + " [-20.7384, -17.1768, 
-24.1848, -13.4080, -25.2121, 14.0100],\n", + " [-12.2906, -11.0539, -12.2814, -9.7475, -13.5809, 9.4918]],\n", + " device='cuda:0', grad_fn=) tensor(4.5595, device='cuda:0',\n", + " grad_fn=)\n" + ] + } + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "results = []\n", + "model.cuda()\n", + "for i in range(0, ds[\"test\"].shape[0], 5):\n", + " x = TimmInputs(labels=torch.Tensor(ds[\"test\"][i:i+5].labels).cuda(),\n", + " spectrogram=torch.Tensor(np.concat(ds[\"test\"][i:i+5].spectrogram)).unsqueeze(1).cuda())\n", + " out = model(x)\n", + " results.append(np.array(torch.argmax(out[\"logits\"], dim=1).cpu()))\n", + " if i > 500:\n", + " break\n" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "1268147f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dec69cd5131646a6bcfa0772c7ee371d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/31537 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + 
+ {
+  "cell_type": "code",
+  "execution_count": 103,
+  "id": "1268147f",
+  "metadata": {},
+  "outputs": [
+   [... widget and audio-player output collapsed: a "Resolving data files: 0/31537" progress bar, then one inline audio player per flagged clip with its predicted class index (values such as 3, 3, 3, 3, 2, 3, 1, 3, 3, 0, 3, 3) printed above it ...]
+  ],
+  "source": [
+   "for i in np.arange(510)[np.concat(results) != 5]:\n",
+   "    print(np.concat(results)[i])\n",
+   "    import IPython\n",
+   "    display(IPython.display.Audio(og_ds[\"test\"][int(i)][\"audio\"][\"path\"]))"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 106,
+  "id": "51e71f6b",
+  "metadata": {},
+  "outputs": [
+   {
+    "data": {
+     "text/plain": [
+      "[array([ 66,  88,  94,  97, 114, 115, 116, 152, 188, 191, 194, 197, 284])]"
+     ]
+    },
+    "execution_count": 106,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
+  "source": [
+   "[np.arange(510)[np.concat(results) != 5]]"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "ed62fc90",
+  "metadata": {},
+  "outputs": [],
+  "source": []
+ }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "whoot",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/whoot_model_training/train.py b/whoot_model_training/train.py
index bed027d..5affeb9 100644
--- a/whoot_model_training/train.py
+++ b/whoot_model_training/train.py
@@ -18,7 +18,7 @@
 import yaml

 from whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments
-from whoot_model_training.data_extractor import buowset_extractor
+from whoot_model_training.data_extractor import xc_extractor
 from whoot_model_training.models import TimmModel, TimmInputs, TimmModelConfig
 from whoot_model_training import CometMLLoggerSupplement

@@ -26,7 +26,6 @@
 from whoot_model_training.preprocessors.spectrogram_preprocessors import (
     SpectrogramParams,
 )
-
 # Uncomment for use with data augmentation
 # from pyha_analyzer.preprocessors import MixItUp, ComposeAudioLabel
 # from audiomentations import (
@@ -64,10 +63,16 @@ def train(config):
         config (dict): the config used for training. Defined in yaml file
     """
     # Extract the dataset
-    ds = buowset_extractor(
-        metadata_csv=config["metadata_csv"],
-        parent_path=config["data_path"],
-        output_path=config["hf_cache_path"],
+    # ds = buowset_extractor(
+    #     metadata_csv=config["metadata_csv"],
+    #     parent_path=config["data_path"],
+    #     output_path=config["hf_cache_path"],
+    # )
+
+    csv_path = "/home/sean/whoot/data/san_diego_xc_aux/xc_meta_aux.json"
+    ds = xc_extractor(
+        xc_dataset_json_path=csv_path,
+        parent_path="/home/sean/whoot/data/san_diego_xc_aux/xeno-canto"
     )

     # Create the model
@@ -75,13 +80,13 @@
     run_name = f"buowset1.1_{model_name}"

     model_config = TimmModelConfig(
-        timm_model=model_name, num_classes=ds.get_num_classes()
-    )
+        timm_model=model_name,
+        num_classes=ds.get_num_classes())

     model = TimmModel(model_config)

     # Preprocessors
-    # Uncomment if doing work with data augmentation
+    # # Uncomment if doing work with data augmentation
     # # Augmentations
     # wav_augs = ComposeAudioLabel([
     #     # AddBackgroundNoise( #We don't have background noise yet
@@ -92,17 +97,17 @@
     #     #     p=0.8
     #     # ),
     #     Gain(
-    #         min_gain_db = -12,
-    #         max_gain_db = 12,
+    #         min_gain_db=-12,
+    #         max_gain_db=12,
     #         p = 0.8
     #     ),
-    #     MixItUp(
-    #         dataset_ref=ds["train"],
-    #         min_snr_db=10,
-    #         max_snr_db=30,
-    #         noise_transform=PolarityInversion(),
-    #         p=0.8
-    #     )
+    #     # MixItUp(
+    #     #     dataset_ref=ds["train"],
+    #     #     min_snr_db=10,
+    #     #     max_snr_db=30,
+    #     #     noise_transform=PolarityInversion(),
+    #     #     p=0.8
+    #     # )
     #     ])

     spectrogram_params = SpectrogramParams()
@@ -161,7 +166,7 @@
     )

     trainer.train()
-    model.save_pretrained("model_checkpoints/test")
+    model.save_pretrained("model_checkpoints/xc_aux_testing")


 def init_env(config: dict):
diff --git a/whoot_model_training/whoot_model_training/data_extractor/__init__.py b/whoot_model_training/whoot_model_training/data_extractor/__init__.py
index 67f215c..927844b 100644
--- a/whoot_model_training/whoot_model_training/data_extractor/__init__.py
+++ b/whoot_model_training/whoot_model_training/data_extractor/__init__.py
@@ -10,10 +10,12 @@
 )
 from .esc50_extractor import esc50_extractor
 from .raw_audio_extractor import raw_audio_extractor
+from .xc_extractor import xc_extractor

 __all__ = [
     "buowset_extractor",
     "buowset_binary_extractor",
     "esc50_extractor",
-    "raw_audio_extractor",
+    "xc_extractor",
+    "raw_audio_extractor"
 ]
diff --git a/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py
index 6fab5bd..bf557da 100644
--- a/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py
+++ b/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py
@@ -18,28 +18,14 @@
 import os
 from dataclasses import dataclass

-import numpy as np
 from datasets import (
     load_dataset,
     Audio,
     DatasetDict,
-    ClassLabel,
-    Sequence,
 )
-from ..dataset import AudioDataset
-
-
-def one_hot_encode(row: dict, classes: list):
-    """One hot Encodes a list of labels.
-    Args:
-        row (dict): row of data in a dataset containing a labels column
-        classes: a list of classes
-    """
-    one_hot = np.zeroes(len(classes))
-    one_hot[row["labels"]] = 1
-    row["labels"] = np.array(one_hot, dtype=float)
-    return row
+from .utils import convert_labeled_dataset_onehot
+from ..dataset import AudioDataset


 @dataclass
@@ -84,17 +70,7 @@
     dataset = load_dataset("csv", data_files=metadata_csv)["train"]

     dataset = dataset.rename_column("category", "labels")
-    dataset = dataset.class_encode_column("labels")
-
-    class_list = dataset.features["labels"].names
-
-    multilabel_class_label = Sequence(ClassLabel(names=class_list))
-
-    dataset = dataset.map(
-        lambda row: one_hot_encode(row, class_list)
-    ).cast_column(
-        "labels", multilabel_class_label
-    )
+    dataset = convert_labeled_dataset_onehot(dataset)

     dataset = dataset.add_column(
         "audio", [
diff --git a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py
index 454276c..39a7350 100644
--- a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py
+++ b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py
@@ -7,11 +7,9 @@
 Rather just a placeholder to help inferance work
 """
-
+from math import floor
 from typing import Any, ClassVar, Union
 import os
-from math import floor
-
 import numpy as np
 from datasets import (
     Audio,
@@ -20,10 +18,9 @@
     ClassLabel,
     Sequence,
     Dataset,
-    table,
+    table
 )
 from datasets.features.features import _FEATURE_TYPES, FeatureType
-
 import librosa
 from tqdm import tqdm
 import pyarrow as pa
@@ -224,9 +221,7 @@
             print(
                 e,
                 file_path,
-                "hit EOF too early, likely corrupted",
-                "| ignoring and continuing"
-            )
+                "failed stat read, reached end of file", "continuing")
             continue
         for i in tqdm(
             range(0, int(floor(clip_length)), chunk_length_sec),
diff --git a/whoot_model_training/whoot_model_training/data_extractor/utils.py b/whoot_model_training/whoot_model_training/data_extractor/utils.py
new file mode 100644
index 0000000..3e67f60
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/data_extractor/utils.py
@@ -0,0 +1,35 @@
+"""Utility functions for data extraction and preprocessing."""
+from datasets import (
+    Dataset,
+    ClassLabel,
+    Sequence,
+)
+
+import numpy as np
+
+
+def one_hot_encode(row: dict, classes: list):
+    """One-hot encodes a row's list of labels.
+
+    Args:
+        row (dict): row of data in a dataset containing a labels column
+        classes: a list of classes
+    """
+    one_hot = np.zeros(len(classes))
+    one_hot[row["labels"]] = 1
+    row["labels"] = np.array(one_hot, dtype=float)
+    return row
+
+
+def convert_labeled_dataset_onehot(dataset: Dataset):
+    """Converts a dataset with a label column to a one-hot encoded version."""
+    dataset = dataset.class_encode_column("labels")
+    class_list = dataset.features["labels"].names
+    multilabel_class_label = Sequence(ClassLabel(names=class_list))
+    dataset = dataset.map(
+        lambda row: one_hot_encode(row, class_list)
+    ).cast_column(
+        "labels",
+        multilabel_class_label
+    )
+    return dataset
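To see what the shared helper does to a label column, here is a minimal sketch; the toy data is hypothetical, not from the repo:

from datasets import Dataset
from whoot_model_training.data_extractor.utils import (
    convert_labeled_dataset_onehot,
)

# Toy dataset: three clips, two species.
toy = Dataset.from_dict({"labels": ["owl", "crow", "owl"]})
toy = convert_labeled_dataset_onehot(toy)
print(toy[0]["labels"])  # a one-hot vector over the class list, e.g. [0.0, 1.0]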
diff --git a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py
new file mode 100644
index 0000000..737832a
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py
@@ -0,0 +1,196 @@
+"""Creates a Dataset from the Xeno-Canto data downloader tool.
+
+See data_downloader/xc.py
+"""
+
+import os
+import shutil
+import json
+from pathlib import Path
+from dataclasses import dataclass
+from collections import Counter
+from pydub import AudioSegment
+import librosa
+from datasets import (
+    Dataset,
+    Audio,
+    DatasetDict,
+    ClassLabel,
+    load_from_disk,
+)
+from ..dataset import AudioDataset
+from .utils import (
+    convert_labeled_dataset_onehot,
+)
+
+
+def filter_by_count(ds, col="en", threshold=10):
+    """Keep only species with more than `threshold` recordings."""
+    count_by_species = Counter(ds[col])
+    return ds.filter(
+        lambda row: count_by_species[row] > threshold,
+        input_columns=[col]
+    )
+
+
+def filter_xc_data(row: dict):
+    """In personal experience, raw XC data is very messy.
+
+    Some files get corrupted.
+    This check verifies that a file can be loaded
+    in the first place.
+    """
+    file_path = row["filepath"]
+    try:
+        # Heuristic: if we can load 3 seconds, the file is probably okay.
+        # Prevents some files from taking forever
+        librosa.load(path=file_path, duration=3)
+        return True
+    except FileNotFoundError as e:
+        print(e, file_path)
+        return False
+    except IOError as e:
+        print(e, file_path)
+        return False
+
+
+def convert_audio_to_flac(row, error_path="bad_files", col="audio"):
+    """Convert any audio to flac for better compression.
+
+    Args:
+        row: row from hugging face table
+        error_path: folder to dump broken files
+        col: column with audio path
+    """
+    file_path = row[col]
+    flac_path = Path(file_path).parent / (Path(file_path).stem + ".flac")
+    if os.path.exists(flac_path):
+        row[col] = str(flac_path)
+        if os.path.exists(file_path):
+            os.remove(file_path)  # Remove original file, we don't need it
+        return row
+    try:
+        wav_audio = AudioSegment.from_file(file_path)
+        wav_audio.export(flac_path, format="flac")
+    except IOError as e:
+        if os.path.exists(file_path):
+            os.makedirs(error_path, exist_ok=True)
+            shutil.move(file_path, error_path)
+
+        print(
+            "ERROR",
+            "move to",
+            os.path.join(error_path, Path(file_path).name),
+            "ERR MSG:",
+            e
+        )
+        row[col] = str(os.path.join(error_path, Path(file_path).name))
+        return row
+    row[col] = str(flac_path)
+    return row
+
+
+@dataclass
+class XCParams():
+    """Parameters that describe the Xeno-Canto dataset.
+
+    validation_fold (int): label for the validation split
+    test_fold (int): label for the test split
+    sample_rate (int): sample rate of the data
+    """
+    validation_fold = 4
+    test_fold = 5
+    sample_rate = 44_100
+
+
+def xc_extractor(
+    xc_dataset_json_path,
+    parent_path,
+    cache_path="data/san_diego_xc_aux/cache",
+    params: XCParams = XCParams(),
+    bad_file_path="data/xc_bad_file"
+):
+    """Extracts data collected from the XC downloader.
+
+    Args:
+        xc_dataset_json_path: JSON file written by the XC downloader
+        parent_path: path to the top-level audio folder
+        cache_path: path to cache the Hugging Face dataset
+        params: dataset parameters (sample rate, folds)
+        bad_file_path: folder where corrupted files are moved
+    """
+    if os.path.exists(cache_path):
+        return load_from_disk(cache_path)
+
+    with open(xc_dataset_json_path, mode="r", encoding="utf-8") as f:
+        xc_recordings_paged = json.load(f)
+
+    xc_recordings = []
+    for page in xc_recordings_paged:
+        xc_recordings.extend(page["recordings"])
+
+    dataset = Dataset.from_list(xc_recordings)
+
+    dataset = dataset.add_column(
+        "labels",
+        dataset["en"],
+        new_fingerprint="labels"
+    )
+    dataset = dataset.class_encode_column("labels")
+    dataset = convert_labeled_dataset_onehot(dataset)
+
+    dataset = dataset.add_column(
+        "audio", [
+            os.path.join(
+                parent_path,
+                file.replace("/", "_")
+            ) for file in dataset["file-name"]
+        ]
+    )
+
+    # Only accept clips shorter than 10 minutes.
+    # Longer clips seem to corrupt more easily...
+    # Format is "#:##", hence length 4
+    dataset = dataset.filter(
+        lambda x: len(x["length"]) == 4
+    )
+
+    # Convert audio to FLAC and fix the file paths
+    dataset = dataset.map(
+        convert_audio_to_flac,
+        fn_kwargs={"error_path": bad_file_path},
+        # num_proc=16
+    )
+
+    dataset = dataset.filter(
+        lambda x: bad_file_path not in x["audio"],
+    )
+
+    dataset = dataset.add_column("filepath", dataset["audio"])
+    dataset = dataset.cast_column(
+        "audio",
+        Audio(sampling_rate=params.sample_rate)
+    )
+
+    dataset = dataset.cast_column(
+        "en", ClassLabel(names=list(set(dataset["en"])))
+    )
+
+    dataset = filter_by_count(dataset)
+
+    train_test = dataset.train_test_split(0.2, stratify_by_column="en")
+    test_val = train_test["test"].train_test_split(
+        0.2,
+        stratify_by_column="en"
+    )
+
+    dataset = AudioDataset(
+        DatasetDict({
+            "train": train_test["train"],
+            "valid": test_val["train"],
+            "test": test_val["test"]})
+    )
+
+    # os.makedirs(cache_path, exist_ok=True)
+    # dataset.save_to_disk(cache_path)
+
+    return dataset
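With the new file in place, train.py (above) calls the extractor like this; a minimal sketch with the same machine-specific placeholder paths:

from whoot_model_training.data_extractor import xc_extractor

# Same call pattern as train.py; substitute your own downloader output paths.
ds = xc_extractor(
    xc_dataset_json_path="data/san_diego_xc_aux/xc_meta_aux.json",
    parent_path="data/san_diego_xc_aux/xeno-canto",
)
print(ds.get_num_classes())  # AudioDataset with "train"/"valid"/"test" splits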
diff --git a/whoot_model_training/whoot_model_training/models/__init__.py b/whoot_model_training/whoot_model_training/models/__init__.py
index d380f40..76456fd 100644
--- a/whoot_model_training/whoot_model_training/models/__init__.py
+++ b/whoot_model_training/whoot_model_training/models/__init__.py
@@ -5,13 +5,25 @@
 """
 from .timm_model import TimmModel, TimmInputs, TimmModelConfig
+from .hf_models import HFModel, HFModelConfig, HFInput
 from .model import Model, ModelInput, ModelOutput
+# from .few_shot_model import (
+#     PerchEmbeddingInput,
+#     PerchFewShotModel,
+#     FewShotModelConfig
+# )

 __all__ = [
     "TimmModel",
     "TimmInputs",
     "TimmModelConfig",
+    "HFModel",
+    "HFModelConfig",
+    "HFInput",
     "Model",
     "ModelInput",
     "ModelOutput",
+    # "PerchEmbeddingInput",
+    # "PerchFewShotModel",
+    # "FewShotModelConfig"
 ]
diff --git a/whoot_model_training/whoot_model_training/models/few_shot_model.py b/whoot_model_training/whoot_model_training/models/few_shot_model.py
new file mode 100644
index 0000000..1a0be28
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/models/few_shot_model.py
@@ -0,0 +1,148 @@
+"""Build a few-shot learning classifier.
+
+Inspired by the work of
+Jacuzzi, G., Olden, J.D., 2025.
+Few-shot transfer learning enables robust acoustic
+monitoring of wildlife communities at the landscape scale.
+Ecological Informatics 90, 103294.
+doi.org/10.1016/j.ecoinf.2025.103294
+
+These models convert their input into an embedding from a large audio model
+and do processing on top of that embedding.
+"""
+
+# from torch import nn, Tensor
+# from perch_hoplite.zoo import model_configs
+# from .model import Model, ModelInput, ModelOutput, has_required_inputs
+
+from transformers import PretrainedConfig
+from .model import ModelInput
+
+
+class EmbeddingModel():
+    """Wrapper for models which are only intended for embeddings."""
+    def embed(self):
+        """Get embedding."""
+        raise NotImplementedError()
+
+    def get_k_neighbors(self):
+        """Get k nearest neighbors."""
+        raise NotImplementedError()
+
+
+class EmbeddingInput(ModelInput):
+    """Wrapper for ModelInputs that are embeddings."""
+    model = EmbeddingModel()
+    embedding_size = 0
+
+    def __init__(
+        self,
+        labels,
+        waveform=None,
+        spectrogram=None
+    ):
+        """Creates an EmbeddingInput.
+
+        Args:
+            labels: label
+            waveform: np array of sound
+            spectrogram: 2d array representing sound
+        """
+        super().__init__(labels, waveform, spectrogram)
+
+        # I keep getting a linting error here,
+        # but there are not too many function args
+        # pylint: disable=too-many-function-args
+        self["embedding"] = self.model.embed(waveform)
+
+
+# Global variable for the Perch embeddings model
+PERCH_MODEL = None
+
+# class PerchEmbeddings(EmbeddingModel):
+#     """Wrapper for getting embeddings from perch."""
+
+#     # Warning: was running into memory issues here.
+#     # Early attempts recreated the model;
+#     # hoping the global var only loads it in once.
+#     if PERCH_MODEL is None:
+#         PERCH_MODEL = model_configs.load_model_by_name('perch_8')
+
+#     model = PERCH_MODEL
+
+#     def embed(self, embeddings):
+#         """Return embeddings."""
+#         return embeddings
+
+
+# class PerchEmbeddingInput(EmbeddingInput):
+#     """Wrapper for an input into a larger model from perch."""
+#     model = PerchEmbeddings()
+#     embedding_size = 1280
+
+
+class FewShotModelConfig(PretrainedConfig):
+    """Config for few-shot classifier models."""
+    def __init__(
+        self,
+        num_classes=200,
+        **kwargs
+    ):
+        """Creates Config.
+
+        Args:
+            num_classes: how many species we want to detect
+        """
+        self.num_classes = num_classes
+        super().__init__(**kwargs)
+
+
+# class PerchFewShotModel(Model, nn.Module):
+#     """Perch model integration with pytorch."""
+#     def __init__(
+#         self,
+#         config: FewShotModelConfig
+#     ):
+#         """Init for PerchFewShotModel.
+
+#         Args:
+#             config (FewShotModelConfig): number of classes to detect
+#         """
+#         super().__init__()
+
+#         self.input_format = PerchEmbeddingInput
+#         self.output_format = ModelOutput
+
+#         self.config = config
+#         assert config.num_classes > 0
+
+#         self.linear = nn.Linear(
+#             self.input_format.embedding_size,
+#             config.num_classes
+#         )
+
+#         self.loss = nn.BCEWithLogitsLoss()
+
+#     @has_required_inputs()
+#     def forward(self, x: PerchEmbeddingInput):
+#         """Run model over x!"""
+#         # Use perch to create embeddings
+#         embeddings = Tensor(
+#             x.model.model.embed(x["waveform"].cpu()).embeddings
+#         ).to(x["waveform"].device)
+
+#         logits = self.linear(embeddings).squeeze(1)
+#         loss = self.loss(logits, x["labels"])
+
+#         return ModelOutput(
+#             logits=logits,
+#             embeddings=embeddings,
+#             loss=loss,
+#             labels=x["labels"]
+#         )
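The Perch classes stay commented out, but the few-shot recipe they sketch, a linear probe over frozen embeddings, can be stated on its own. A self-contained sketch with random stand-in embeddings sized like Perch's (1280-dim):

import torch
from torch import nn

# Stand-in batch: 8 clips embedded by a frozen backbone, 6 classes.
embeddings = torch.randn(8, 1280)
labels = torch.zeros(8, 6)
labels[torch.arange(8), torch.randint(0, 6, (8,))] = 1.0

probe = nn.Linear(1280, 6)  # the only trainable piece
loss = nn.BCEWithLogitsLoss()(probe(embeddings), labels)
loss.backward()  # gradients reach the probe alone; the backbone stays frozen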
diff --git a/whoot_model_training/whoot_model_training/models/hf_models.py b/whoot_model_training/whoot_model_training/models/hf_models.py
new file mode 100644
index 0000000..bb428c8
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/models/hf_models.py
@@ -0,0 +1,144 @@
+"""Wrapper around the Hugging Face model API!"""
+
+from contextlib import nullcontext
+
+from transformers import AutoFeatureExtractor, AutoModel, PretrainedConfig
+from torch import nn
+import torch
+
+from .model import Model, ModelInput, ModelOutput, has_required_inputs
+
+
+class HFInput(ModelInput):
+    """Input for Hugging Face models.
+
+    Specifies that HF models need labels and spectrograms that are Tensors.
+    """
+
+    def __init__(self,
+                 labels=None,
+                 spectrogram=None,
+                 waveform=None,
+                 extractor_path="DBD-research-group/Bird-MAE-Base"):
+        """Creates an HFInput.
+
+        Args:
+            labels: the data's label for this batch
+            spectrogram: legacy, kept for API compatibility
+            waveform: legacy, kept for API compatibility
+            extractor_path: path to the Hugging Face preprocessor
+        """
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
+            extractor_path,
+            trust_remote_code=True)
+        super().__init__(labels, waveform, spectrogram)
+
+    def __call__(self, labels, spectrogram=None, waveform=None):
+        """Create ModelInputs for HFModels.
+
+        Slightly different API for HFInput: when creating an input,
+        use the preprocessor from Hugging Face.
+        """
+        mel_spectrogram = self.feature_extractor(waveform)
+        return ModelInput(labels, waveform=None, spectrogram=mel_spectrogram)
+
+
+class HFModelConfig(PretrainedConfig):
+    """Config for Hugging Face hub models!"""
+    def __init__(
+        self,
+        path: str = "DBD-research-group/Bird-MAE-Huge",
+        num_classes: int = 6,
+        embeddings_size: int = 1280,
+        freeze_backbone: bool = True,
+        **kwargs
+    ):
+        """Creates Config.
+
+        Args:
+            path (str): url to pull from hf model zoo
+            num_classes (int): number of classes in dataset, for cls
+            embeddings_size (int): size of output of model
+            freeze_backbone (bool): freeze the backbone of a model
+        """
+        self.path = path
+        self.num_classes = num_classes
+        self.embeddings_size = embeddings_size
+        self.freeze_backbone = freeze_backbone
+        super().__init__(**kwargs)
+
+
+class HFModel(Model, nn.Module):
+    """Model that uses a Hugging Face hub backbone."""
+    config_class = HFModelConfig
+
+    def __init__(
+        self,
+        config: HFModelConfig
+    ):
+        """Init for HFModel.
+
+        Args:
+            config (HFModelConfig): backbone path, number of classes,
+                embedding size, and whether to freeze the backbone
+        """
+        super().__init__()
+        self.input_format = ModelInput
+        self.output_format = ModelOutput
+        self.config = config
+        assert config.num_classes > 0
+
+        # Pretrained audio backbone from the Hugging Face hub
+        self.backbone = AutoModel.from_pretrained(
+            config.path,
+            trust_remote_code=True
+        )
+
+        # Linear head mapping backbone embeddings to class logits
+        self.linear = nn.Linear(config.embeddings_size, config.num_classes)
+
+        # Different losses if you want to train for different problems.
+        # BCEWithLogitsLoss is the default, as in bioacoustics the problem
+        # tends to be multilabel:
+        # the probability of class A occurring doesn't
+        # change the probability of class B, and
+        # many individuals can make calls at the same time!
+        self.loss = nn.BCEWithLogitsLoss()
+
+    def set_custom_loss(self, loss_fn):
+        """Set a different loss function.
+
+        For cases where we don't want BCEWithLogitsLoss
+
+        Args:
+            loss_fn: Function to compute loss, ideally in pytorch
+        """
+        self.loss = loss_fn
+
+    @has_required_inputs()
+    def forward(self, x: HFInput) -> ModelOutput:
+        """Model forward function.
+
+        Args:
+            x (HFInput): The specific input format for HF models
+
+        Returns:
+            (ModelOutput): The model output (logits),
+            latent space representations (embeddings), loss and labels.
+        """
+        with torch.no_grad() if self.config.freeze_backbone else nullcontext():
+            embed = self.backbone(
+                x.spectrogram.to(self.device)
+            ).last_hidden_state
+        # Head runs outside no_grad so it stays trainable when frozen
+        logits = self.linear(embed)
+
+        return ModelOutput(
+            logits=logits,
+            embeddings=embed,
+            loss=self.loss(logits, x.labels),
+            labels=x.labels
+        )
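A minimal usage sketch for the wrapper above; the spectrogram tensor is a random stand-in (real shapes come from the checkpoint's feature extractor), and both constructors pull weights from the hub:

import torch
from whoot_model_training.models import HFModel, HFModelConfig, HFInput

config = HFModelConfig(num_classes=6)  # Bird-MAE-Huge backbone by default
model = HFModel(config)

x = HFInput(
    labels=torch.zeros(2, 6),
    spectrogram=torch.randn(2, 1, 128, 1024),  # stand-in extractor output
)
# out = model(x)  # -> ModelOutput with logits, embeddings, loss, labels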
diff --git a/whoot_model_training/whoot_model_training/models/model.py b/whoot_model_training/whoot_model_training/models/model.py
index 9b1f450..69c2199 100644
--- a/whoot_model_training/whoot_model_training/models/model.py
+++ b/whoot_model_training/whoot_model_training/models/model.py
@@ -136,13 +136,14 @@
             waveform: raw audio signal
             spectrogram: 2d matrix to represent the waveform
         """
-        super().__init__(
-            {
-                "labels": labels,
-                "waveform": waveform,
-                "spectrogram": spectrogram
-            }
-        )
+        super().__init__({
+            "labels": labels,
+            "waveform": waveform,
+            "spectrogram": spectrogram
+        })
+        self.labels = labels
+        self.waveform = waveform
+        self.spectrogram = spectrogram

     def items(self):
         """Get all items in dict.
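The model.py change mirrors the dict entries as attributes, which is what lets hf_models.py write x.spectrogram and x.labels instead of subscripting. A small demonstration with toy tensors:

import torch
from whoot_model_training.models import ModelInput

x = ModelInput(
    labels=torch.zeros(2, 6),
    waveform=None,
    spectrogram=torch.randn(2, 1, 256, 256),
)
assert x["spectrogram"] is x.spectrogram  # dict key and attribute now agree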
diff --git a/whoot_model_training/whoot_model_training/preprocessors/__init__.py b/whoot_model_training/whoot_model_training/preprocessors/__init__.py
index d7bb48a..65b121f 100644
--- a/whoot_model_training/whoot_model_training/preprocessors/__init__.py
+++ b/whoot_model_training/whoot_model_training/preprocessors/__init__.py
@@ -7,7 +7,15 @@
 the __get_item__ function of a dataset
 """
-from .base_preprocessor import MelModelInputPreprocessor
-from .spectrogram_preprocessors import BuowMelSpectrogramPreprocessors
+from .base_preprocessor import (
+    MelModelInputPreprocessor, WaveformInputPreprocessor
+)
+from .spectrogram_preprocessors import (
+    BuowMelSpectrogramPreprocessors
+)

-__all__ = ["MelModelInputPreprocessor", "BuowMelSpectrogramPreprocessors"]
+__all__ = [
+    "MelModelInputPreprocessor",
+    "BuowMelSpectrogramPreprocessors",
+    "WaveformInputPreprocessor"
+]
diff --git a/whoot_model_training/whoot_model_training/preprocessors/augmentations.py b/whoot_model_training/whoot_model_training/preprocessors/augmentations.py
new file mode 100644
index 0000000..187b78a
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/preprocessors/augmentations.py
@@ -0,0 +1,23 @@
+"""Contains various data augmentation techniques for bioacoustics.
+
+Notes: relies heavily on the audiomentations library.
+
+Basically, combine augmentations with ComposeAudioLabel.
+
+For clarity, put augmentation imports here.
+
+For devs:
+To create a new augmentation, create an AudioLabelPreprocessor.
+"""
+from pyha_analyzer.preprocessors.augmentations import (
+    ComposeAudioLabel, MixItUp, AudioLabelPreprocessor
+)
+from audiomentations import Gain, PolarityInversion
+
+__all__ = [
+    "ComposeAudioLabel",
+    "MixItUp",
+    "AudioLabelPreprocessor",
+    "Gain",
+    "PolarityInversion"
+]
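These re-exports are meant to be composed the way the commented-out block in train.py does; a minimal sketch (the gain range and probability mirror that block, not tuned defaults):

from whoot_model_training.preprocessors.augmentations import (
    ComposeAudioLabel,
    Gain,
)

# Waveform-level gain jitter, applied with probability 0.8.
wav_augs = ComposeAudioLabel([
    Gain(min_gain_db=-12, max_gain_db=12, p=0.8),
])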
diff --git a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py
index 89208c6..6e049ca 100644
--- a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py
+++ b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py
@@ -16,6 +16,8 @@
 
 from pyha_analyzer.preprocessors import PreProcessorBase
 
+from .default_preprocessor import DefaultPreprocessor
+
 from .spectrogram_preprocessors import (
     BuowMelSpectrogramPreprocessors,
     SpectrogramParams,
@@ -23,6 +25,10 @@
 )
 
 from ..models.model import ModelInput
+
+from .waveform_preprocessors import (
+    WaveformPreprocessors
+)
 
 
 class SpectrogramModelInPreprocessors(PreProcessorBase):
     """Defines a preprocessor that after formatting the audio.
@@ -32,7 +38,7 @@
 
     def __init__(
         self,
-        spec_preprocessor: PreProcessorBase,
+        spec_preprocessor: DefaultPreprocessor,
         model_input: ModelInput,
     ):
         """Wrapper to get the raw spectrogram output of spec_preprocessor.
@@ -107,3 +113,57 @@
             spectrogram_params=spectrogram_params
         )
         super().__init__(spec_preprocessor, model_input)
+
+
+class WaveformInputPreprocessor(SpectrogramModelInPreprocessors):
+    """Demo of how SpectrogramModelInPreprocessors works.
+
+    Uses a waveform preprocessor, WaveformPreprocessors, instead of a
+    spectrogram preprocessor.
+
+    This was created in part because the legacy implementation of
+    SpectrogramModelInputPreprocessors had these parameters and subclassed
+    BuowMelSpectrogramPreprocessors. This class replicates the
+    interface of the old SpectrogramModelInputPreprocessors
+    class with the new functionality.
+    """
+    def __init__(
+        self,
+        model_input: ModelInput,
+        duration=5,
+        augments: Augmentations = Augmentations(),
+    ):
+        """Creates an online preprocessor for waveform-based models.
+
+        Formats input into the specific ModelInput format.
+
+        Args:
+            model_input (ModelInput): how the model likes its input
+                data formatted
+            duration (int): length in seconds of input
+            augments (Augmentations): audio and spectrogram augmentations
+                to run
+        """
+        wav_preprocessor = WaveformPreprocessors(
+            duration=duration,
+            sr=32_000,
+            augments=augments,
+        )
+        super().__init__(wav_preprocessor, model_input)
+
+    def __call__(self, batch: dict) -> ModelInput:
+        """Processes a batch of AudioDataset rows.
+
+        For this specific preprocessor, it crops and augments the raw
+        waveform, then formats the data as a ModelInput.
+        """
+        batch = self.spec_preprocessor(batch)
+        return self.model_input(
+            labels=batch["labels"],
+            waveform=batch["audio"]
+        )
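A hand-driven sketch of the new preprocessor, assuming the batch schema that load_audio() in default_preprocessor.py below expects (rows with "array", "sampling_rate", "path"); passing the ModelInput class itself as model_input mirrors how __call__ uses it:

    import numpy as np
    from whoot_model_training.preprocessors import WaveformInputPreprocessor
    from whoot_model_training.models.model import ModelInput

    prep = WaveformInputPreprocessor(model_input=ModelInput, duration=5)

    batch = {
        "audio": [{
            "array": np.zeros(32_000 * 10, dtype=np.float32),  # 10 s clip
            "sampling_rate": 32_000,
            "path": "",  # unused when "array" is populated
        }],
        "labels": [[0.0, 1.0]],
    }
    model_in = prep(batch)  # ModelInput with labels and a 5 s waveform crop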
diff --git a/whoot_model_training/whoot_model_training/preprocessors/default_preprocessor.py b/whoot_model_training/whoot_model_training/preprocessors/default_preprocessor.py
new file mode 100644
index 0000000..32f57bd
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/preprocessors/default_preprocessor.py
@@ -0,0 +1,101 @@
+"""Defines a default preprocessor class.
+
+This allows for defining a set of common audio loading utilities.
+"""
+from dataclasses import dataclass
+import librosa
+import numpy as np
+from pyha_analyzer.preprocessors import PreProcessorBase
+
+
+@dataclass
+class Augmentations():
+    """Dataclass for the augmentations of the model.
+
+    audio (list[dict]): per item, the key is the name of the augmentation
+        and the value is the augmentation
+    spectrogram (list[dict]): same idea, but the augmentations are
+        applied to spectrograms
+    """
+    audio = None
+    spectrogram = None
+
+
+class DefaultPreprocessor(PreProcessorBase):
+    """Default Preprocessor class."""
+    def __init__(self, name, duration, sr, *args, **kwargs):
+        """Initializes the DefaultPreprocessor.
+
+        Args:
+            name (str): name of preprocessor for logging
+            duration (float): max length in seconds of audio chunk
+            sr (int/None): sample rate to standardize audio to
+        """
+        super().__init__(name, *args, **kwargs)
+        self.duration = duration
+        self.sr = sr
+
+    def load_audio(self, batch, item_idx):
+        """Load audio from either an in-memory array or a file path.
+
+        Args:
+            batch (dict): AudioDataset batch
+            item_idx (int): index of the item in the batch to process
+        Returns:
+            y (np.ndarray): loaded audio array
+            sr (int): sample rate of the audio
+        Raises:
+            IOError: if the file is too long or unreadable
+        """
+        try:
+            if len(batch["audio"][item_idx]["array"]) > 10:
+                y = batch["audio"][item_idx]["array"]
+                sr = batch["audio"][item_idx]["sampling_rate"]
+            else:
+                if librosa.get_duration(
+                    path=batch["audio"][item_idx]["path"]
+                ) > 2 * 60:
+                    raise IOError("File too long to process")
+
+                y, sr = librosa.load(
+                    path=batch["audio"][item_idx]["path"],
+                    sr=self.sr
+                )
+
+        except IOError as e:
+            # Re-raise so the caller can skip this item; the file is
+            # likely corrupted or unreadable
+            print("File is likely corrupted, skipping", e)
+            raise IOError from e
+
+        return y, sr
+
+    def augment_audio(
+        self,
+        y: np.ndarray,
+        sr: int,
+        start: float,
+        label: str,
+        augments: Augmentations
+    ):
+        """Crop audio to the training duration and apply augmentations.
+
+        Args:
+            y: audio array
+            sr: sample rate
+            start: starting point in seconds to crop audio
+            label: label associated with audio
+            augments: augmentations to apply
+        """
+        # Handle out of bound issues
+        end_sr = int(start * sr) + int(sr * self.duration)
+        if y.shape[-1] <= end_sr:
+            y = np.pad(y, end_sr - y.shape[-1])
+
+        # Audio Based Augmentations
+        if augments.audio is not None:
+            y, label = augments.audio(y, sr, label)
+
+        new_y = y[int(start * sr):end_sr]
+        if new_y.shape[-1] < int(sr * self.duration):
+            raise IOError("Audio too short after augmentation")
+
+        return new_y, label
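To show what DefaultPreprocessor buys its subclasses, here is a hypothetical minimal subclass (not part of this diff) that only loads and crops, reusing load_audio and augment_audio exactly as the preprocessors below do:

    import numpy as np
    from whoot_model_training.preprocessors.default_preprocessor import (
        Augmentations, DefaultPreprocessor
    )

    class CropOnlyPreprocessor(DefaultPreprocessor):
        """Hypothetical: crop each clip from t=0 with no augmentations."""

        def __init__(self, duration=5, sr=32_000):
            super().__init__(
                name="CropOnlyPreprocessor", duration=duration, sr=sr
            )
            self.augments = Augmentations()

        def __call__(self, batch):
            new_audio, new_labels = [], []
            for i in range(len(batch["audio"])):
                y, sr = self.load_audio(batch, i)
                y, label = self.augment_audio(
                    y, sr, 0, batch["labels"][i], self.augments
                )
                new_audio.append(y)
                new_labels.append(label)
            batch["audio"] = new_audio
            batch["labels"] = np.array(new_labels, dtype=np.float32)
            return batch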
diff --git a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py
index 4d7e786..db62310 100644
--- a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py
+++ b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py
@@ -2,14 +2,13 @@
 
 Pulled from pyha_analyzer/preprocessors/spectogram_preprocessors.py
 """
-
 from dataclasses import dataclass
 
 import librosa
 import numpy as np
 from torchvision import transforms
 
-from pyha_analyzer.preprocessors import PreProcessorBase
+from .default_preprocessor import DefaultPreprocessor, Augmentations
 
 
 @dataclass
@@ -28,21 +27,7 @@ class SpectrogramParams:
     n_mels: int = 256
 
 
-@dataclass
-class Augmentations:
-    """Dataclass for the augmentations of the model.
-
-    audio (list[dict]): per item key name of augmentation,
-        value is the augmentation
-    spectrogram (list[dict]): same idea but augmentations
-        applied onto spectrograms
-    """
-
-    audio = None
-    spectrogram = None
-
-
-class BuowMelSpectrogramPreprocessors(PreProcessorBase):
+class BuowMelSpectrogramPreprocessors(DefaultPreprocessor):
     """Preprocessor for processing audio into spectrograms.
 
     Particularly for the buow dataset
@@ -72,42 +57,38 @@
         self.n_mels = spectrogram_params.n_mels
         self.spectrogram_params = spectrogram_params
 
-        super().__init__(name="MelSpectrogramPreprocessor")
+        super().__init__(
+            name="MelSpectrogramPreprocessor", duration=duration, sr=self.sr
+        )
 
     def __call__(self, batch):
         """Process a batch of data from an AudioDataset."""
+        # pylint: disable=duplicate-code
         new_audio = []
         new_labels = []
         for item_idx in range(len(batch["audio"])):
             label = batch["labels"][item_idx]
-            y, sr = (
-                batch["audio"][item_idx]["array"],
-                batch["audio"][item_idx]["sampling_rate"],
-            )
+            y, sr = self.load_audio(batch, item_idx)
             start = 0
 
-            # Handle out of bound issues
-            end_sr = int(start * sr) + int(sr * self.duration)
-            if y.shape[-1] <= end_sr:
-                y = np.pad(y, end_sr - y.shape[-1])
-
-            # Audio Based Augmentations
-            if self.augments.audio is not None:
-                y, label = self.augments.audio(y, sr, label)
+            y, label = self.augment_audio(y, sr, start, label, self.augments)
 
             pillow_transforms = transforms.ToPILImage()
+            spec = librosa.feature.melspectrogram(
+                y=y,
+                sr=sr,
+                n_fft=self.n_fft,
+                hop_length=self.hop_length,
+                power=self.power,
+                n_mels=self.n_mels,
+            )
+            pcen_s = librosa.pcen(spec * (2**31))
+
             mels = (
                 np.array(
                     pillow_transforms(
-                        librosa.feature.melspectrogram(
-                            y=y[int(start * sr):end_sr],
-                            sr=sr,
-                            n_fft=self.n_fft,
-                            hop_length=self.hop_length,
-                            power=self.power,
-                            n_mels=self.n_mels,
-                        )
+                        pcen_s
                     ),
                     np.float32,
                 )[np.newaxis, ::]
@@ -120,7 +101,7 @@
             new_audio.append(mels)
             new_labels.append(label)
 
-        batch["audio"] = new_audio
+        batch["audio"] = np.concatenate(new_audio)
         batch["labels"] = np.array(new_labels, dtype=np.float32)
 
         return batch
@@ -155,8 +136,7 @@ class PCENMelSpectrogramPreprocessors(BuowMelSpectrogramPreprocessors):
 
     def __call__(self, batch):
         """Process a batch of data from an AudioDataset."""
-        new_audio = []
-        new_labels = []
+        new_audio, new_labels = [], []
         for item_idx in range(len(batch["audio"])):
             label = batch["labels"][item_idx]
             y, sr = (
@@ -164,20 +144,12 @@
                 batch["audio"][item_idx]["sampling_rate"],
             )
             start = 0
-
-            # Handle out of bound issues
-            end_sr = int(start * sr) + int(sr * self.duration)
-            if y.shape[-1] <= end_sr:
-                y = np.pad(y, end_sr - y.shape[-1])
-
-            # Audio Based Augmentations
-            if self.augments.audio is not None:
-                y, label = self.augments.audio(y, sr, label)
+            y, label = self.augment_audio(y, sr, start, label, self.augments)
 
             pillow_transforms = transforms.ToPILImage()
 
             spec = librosa.feature.melspectrogram(
-                y=y[int(start * sr):end_sr],
+                y=y,
                 sr=sr,
                 n_fft=self.n_fft,
                 hop_length=self.hop_length,
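A standalone sketch of the PCEN pipeline the refactored __call__ now uses; parameter values mirror the SpectrogramParams defaults above, and the 2**31 rescale follows the librosa.pcen documentation for float input:

    import librosa
    import numpy as np

    y = np.random.uniform(-1, 1, 32_000 * 5).astype(np.float32)  # 5 s @ 32 kHz
    spec = librosa.feature.melspectrogram(
        y=y, sr=32_000, n_fft=2048, hop_length=256, power=2.0, n_mels=256
    )
    pcen_s = librosa.pcen(spec * (2**31))  # per-channel energy normalization
    print(pcen_s.shape)  # (n_mels, n_frames) == (256, 626)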
diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py
new file mode 100644
index 0000000..4a5f5af
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py
@@ -0,0 +1,105 @@
+"""Defines preprocessors that return raw waveforms rather than spectrograms.
+
+Pulled from pyha_analyzer/preprocessors/spectogram_preprocessors.py
+"""
+
+import numpy as np
+from .default_preprocessor import DefaultPreprocessor, Augmentations
+
+
+class WaveformPreprocessors(DefaultPreprocessor):
+    """Preprocessor that crops and augments raw audio waveforms.
+
+    Particularly for the buow dataset
+    """
+
+    def __init__(
+        self,
+        duration=5,
+        sr=None,
+        augments: Augmentations = Augmentations(),
+    ):
+        """Defines a WaveformPreprocessors.
+
+        Args:
+            duration (float): length of chunk of data to train on
+            sr (int/None): sample rate to standardize audio to,
+                defaults to the file's sample rate
+            augments (Augmentations): augmentations to apply to waveforms
+        """
+        self.duration = duration
+        self.augments = augments
+        self.sr = sr
+
+        super().__init__(
+            name="WaveformPreprocessor",
+            duration=duration,
+            sr=self.sr)
+
+    def __call__(self, batch):
+        """Process a batch of data from an AudioDataset."""
+        new_audio = []
+        new_labels = []
+        for item_idx in range(len(batch["audio"])):
+            label = batch["labels"][item_idx]
+
+            y, sr = self.load_audio(batch, item_idx)
+
+            # Random crop; clamp to 0 so clips shorter than `duration`
+            # don't produce a negative start
+            start = np.random.uniform(
+                0, max(0, len(y) / sr - self.duration)
+            )
+
+            y, label = self.augment_audio(y, sr, start, label, self.augments)
+
+            new_audio.append(y)
+            new_labels.append(label)
+
+        batch["audio"] = new_audio
+        batch["labels"] = np.array(new_labels, dtype=np.float32)
+
+        return batch
+
+    def get_augmentations(self):
+        """Returns the augmentations.
+
+        Perhaps for logging purposes
+
+        Returns:
+            (Augmentations) all the augmentations
+        """
+        return self.augments
+
+    def __repr__(self):
+        """Use representation to describe the augmentations.
+
+        Returns:
+            (str) all information about this preprocessor
+        """
+        return (
+            f"""{self.name}
+            Augmentations: {self.augments}
+            """
+        )
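Attaching a real waveform augmentation: augment_audio above expects augments.audio to be a callable with a (y, sr, label) -> (y, label) contract, so any audiomentations transform can be adapted to it. The adapter below is illustrative; Gain's keyword names are those of recent audiomentations releases:

    import numpy as np
    from audiomentations import Gain
    from whoot_model_training.preprocessors.default_preprocessor import (
        Augmentations
    )
    from whoot_model_training.preprocessors.waveform_preprocessors import (
        WaveformPreprocessors
    )

    class GainAdapter:
        """Adapt an audiomentations transform to the (y, sr, label) contract."""

        def __init__(self):
            self.gain = Gain(min_gain_db=-6.0, max_gain_db=6.0, p=1.0)

        def __call__(self, y, sr, label):
            y_aug = self.gain(samples=y.astype(np.float32), sample_rate=sr)
            return y_aug, label

    augs = Augmentations()
    augs.audio = GainAdapter()
    prep = WaveformPreprocessors(duration=5, sr=32_000, augments=augs)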
diff --git a/whoot_model_training/whoot_model_training/trainer.py b/whoot_model_training/whoot_model_training/trainer.py
index 8e1a469..a83179f 100644
--- a/whoot_model_training/whoot_model_training/trainer.py
+++ b/whoot_model_training/whoot_model_training/trainer.py
@@ -14,6 +14,8 @@
 import torch
 from tqdm import tqdm
 
+import datasets
+
 from pyha_analyzer import PyhaTrainingArguments
 from pyha_analyzer import PyhaTrainer
 
@@ -94,7 +96,7 @@
             logger (CometMLLoggerSupplement):
                 Class that adds additional logging
                 On top of logging done by PyhaTrainer
-            preprocessor (PreProcessorBase):
+            preprocessor (DefaultPreprocessor):
                 Preprocessor used for formatting the data
         """
         metrics = WhootMutliClassMetrics(dataset.get_class_labels().names)
@@ -114,31 +116,43 @@
         )
 
     def predict(
-        self,
-        test_dataset: AudioDataset,
-        ignore_keys=None,  # overloaded, please ignore
-        metric_key_prefix: str = "test",  # overloaded, please ignore
-    ):
-        """Run Inferance with trained model!
-
-        Args:
+            self, test_dataset: AudioDataset,
-            an AudioDataset to collect predictions from
-            ignore_keys: legacy
-            metric_key_prefix: legacy
+            ignore_keys=None,
+            metric_key_prefix: str = "test",
+            save_path=""):
+        """Run inference on a given dataset.
 
-        ignore_keys, and metric_key_prefix exist for subclass overriding
+        Allows for getting predicted outputs to label a new dataset.
 
+        Args:
+            test_dataset (AudioDataset): dataset to get preds from.
+                It has labels, but they are meaningless in this method
+            ignore_keys: N/A
+            metric_key_prefix: N/A, defaults to "test"
+            save_path (str): directory to periodically checkpoint
+                predictions to; must be set for checkpointing to work
+
         Returns:
             test_dataset with a new col: "pred"
         """
         test_dataloader = self.get_test_dataloader(test_dataset)
         preds = []
+        data_selected = []
+        count = 0
         for batch in tqdm(test_dataloader):
-            preds.append(
-                self.model(
-                    self.model.input_format(**batch)
-                )["logits"].detach().cpu()
-            )
-
-        dataset = test_dataset.to_dict()
+            # fp16 logits halve the memory held across a long run
+            pred = self.model(
+                self.model.input_format(**batch)
+            )["logits"].detach().cpu().half()
+            preds.append(pred)
+            data_selected.extend(range(count, count + len(pred)))
+            count += len(pred)
+
+            # Periodic checkpoint of partial predictions; note this only
+            # fires when count lands exactly on a multiple of 101, and it
+            # assumes save_path points somewhere writable
+            if count % 101 == 0:
+                dataset = test_dataset.with_format()
+                out = dataset.select(data_selected).to_dict()
+                out["pred"] = torch.concat(preds).detach().numpy()
+                # saves as a directory
+                datasets.Dataset.from_dict(out).save_to_disk(save_path)
+
+        dataset = test_dataset.with_format()
+        dataset = dataset.select(data_selected).to_dict()
+        dataset["pred"] = torch.concat(preds).detach().numpy()
 
         return dataset
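Finally, a hedged sketch of driving the new predict() with checkpointing. The trainer's class statement is outside this hunk, so `trainer` below is just an instance of it, and the save directory is a hypothetical path; the reload line uses the real Hugging Face datasets API that pairs with save_to_disk:

    import datasets

    def run_inference(trainer, unlabeled_dataset, save_dir="predictions/run1"):
        """Collect logits for an unlabeled AudioDataset, checkpointing as we go."""
        return trainer.predict(unlabeled_dataset, save_path=save_dir)

    # If a long run dies partway through, reload the last checkpoint:
    # partial = datasets.Dataset.load_from_disk("predictions/run1")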