From a840ee5f6fb7a6d3e70f9b23913e3ea8629465d3 Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 10 Oct 2025 15:07:35 -0700 Subject: [PATCH 01/18] Adds working code for XC data and few shot model training --- data_downloader/xc.py | 76 ++++++++ data_downloader/xc_aux_downloader.py | 100 +++++++++++ pyproject.toml | 11 +- .../data_extractor/Jacuzzi_Olden_extractor.py | 62 +++++++ .../data_extractor/__init__.py | 21 ++- .../data_extractor/xc_extractor.py | 169 ++++++++++++++++++ .../whoot_model_training/models/__init__.py | 6 +- .../models/few_shot_model.py | 112 ++++++++++++ .../preprocessors/__init__.py | 6 +- .../preprocessors/augmentations.py | 13 ++ .../preprocessors/base_preprocessor.py | 62 +++++++ .../preprocessors/waveform_preprocessors.py | 130 ++++++++++++++ 12 files changed, 763 insertions(+), 5 deletions(-) create mode 100644 data_downloader/xc.py create mode 100644 data_downloader/xc_aux_downloader.py create mode 100644 whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py create mode 100644 whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py create mode 100644 whoot_model_training/whoot_model_training/models/few_shot_model.py create mode 100644 whoot_model_training/whoot_model_training/preprocessors/augmentations.py create mode 100644 whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py diff --git a/data_downloader/xc.py b/data_downloader/xc.py new file mode 100644 index 0000000..a11d3c9 --- /dev/null +++ b/data_downloader/xc.py @@ -0,0 +1,76 @@ +import requests +import os +import json +import urllib.parse + +class XenoCantoDownloader(): + def __init__(self, api_key=None): + self.endpoint_url = "https://xeno-canto.org/api/3/recordings" + self.api_key = os.environ["XC_API_KEY"] if api_key is None else api_key + assert self.api_key is not None, "API KEY MISSING: Put API key in Enviroment Var!" + + def __call__(self, + query = None, + loc=None, + box=None): + + if query is None: + query = self.build_query( + loc=loc, + box=None, + ) + + page_datas = [] + page_data = self.get_page(query, page=1) + page_datas.append(page_data) + + # Get rest of data! 
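+        # The first response's "numPages" field gives the total page count;
+        # each remaining page is fetched below by re-running the same query
+        # with page=i, one request at a time, so large result sets are slow.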
+ for i in range(2, page_data["numPages"] + 1): + page_data = self.get_page(query, page=i) + page_datas.append(page_data) + + return page_datas + + def concat_recording_data(self, page_datas): + new_page_data = [] + for page_data in page_datas: + new_page_data = new_page_data + page_data["recordings"] + return new_page_data + + def build_query( + self, + loc="San Diego, California, United States of America", + box=None, + ): + search_tags = "" + if loc is not None: + search_tags += f"loc:\"{loc}\"+" + return search_tags[:-1] #remove last + + + def get_page(self, query, page=1): + res = requests.get(self.endpoint_url + "?"+ urllib.parse.urlencode({ + "query": query, + "key": self.api_key, + "page": page + })) + if res.status_code == 200: + return json.loads(res.text) + else: + {} + + # def download_files(self, data): + # if type(data) == dict: + # data = self.concat_recording_data(self, data) + + # for recording in data: + # requests + +if __name__ == "__main__": + # parser = argparse.ArgumentParser( + # description='Input Directory Path' + # ) + # parser.add_argument('meta', type=str, + # help='Path to metadata csv') + # args = parser.parse_args() + xcd = XenoCantoDownloader() + print(xcd()) \ No newline at end of file diff --git a/data_downloader/xc_aux_downloader.py b/data_downloader/xc_aux_downloader.py new file mode 100644 index 0000000..ca1fbac --- /dev/null +++ b/data_downloader/xc_aux_downloader.py @@ -0,0 +1,100 @@ +# %% +from xc import XenoCantoDownloader +from dotenv import load_dotenv +import os + +# Load environment variables from the .env file +load_dotenv() + +xcd = XenoCantoDownloader(api_key=os.environ["XC_API_KEY"]) + + + +# %% +import json + +with open("data/xc_meta.json", mode="r") as f: + data = json.load(f) + +species = { recording["en"] for page in data for recording in page["recordings"] } + +# %% +len({recording["en"] for page in data for recording in page["recordings"] }) + +# %% +len(species) + +# %% +data = [] +import tqdm +for specie in tqdm.tqdm(list(species)): + data.append(xcd(query=f'en:"{specie}"')) + +# %% +import itertools +data = list(itertools.chain.from_iterable(data)) + +# %% +with open("xc_meta_aux.json", mode="w") as f: + json.dump(data, f, indent=4) + +# %% +import requests + +# %% +import shutil +import os +from pathlib import Path +from multiprocessing.pool import ThreadPool + +# https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests +def download_file(url, local_filename, dry_run=False): + if os.path.exists(local_filename): + return local_filename + + try: + with requests.get(url, stream=True) as r: + with open(local_filename, 'wb') as f: + if not dry_run: + shutil.copyfileobj(r.raw, f) + else: + print(local_filename) + + return local_filename + except IOError as e: + print(e, flush=True) + return None + +def download_files(xcd, data, parent_folder="data/xeno-canto_aux", workers = 4): + def prep_download(args): + url = args[0] + file_path = args[1] + return download_file(url, file_path) + + os.makedirs(parent_folder, exist_ok=True) + + if "recordings" in data[0]: + data = xcd.concat_recording_data(data) + download_data = [ + (recording["file"], Path(parent_folder) / Path(recording["file-name"])) + for recording in data + ] + pool = ThreadPool(workers) + results = pool.imap_unordered(prep_download, download_data) + pool.close() + return results + +results = download_files(xcd, data) +results + +# %% +import pandas as pd +recordings = xcd.concat_recording_data(data) +df = pd.DataFrame(recordings) + +df.shape + +# 
%%
+
+
+
diff --git a/pyproject.toml b/pyproject.toml
index 0f44af2..25fd184 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
     "numba==0.61.0",
     "pandas>=2.3.0",
     "pydub>=0.25.1",
+    "python-dotenv>=1.1.1",
     "pyyaml>=6.0.2",
     "scikit-learn>=1.7.0",
     "tqdm>=4.67.1",
 ]
@@ -40,9 +41,17 @@ model-training = [
     "comet-ml>=3.43.2",
 ]
 
+perch = [
+    "perch-hoplite>=0.1.0",
+    "tensorflow-hub>=0.16.1",
+    "tensorflow[and-cuda]>=2.20.0",
+]
+
 notebooks = [
     "ipykernel>=6.29.5",
     "ipywidgets>=8.1.6",
+    "matplotlib>=3.10.6",
+    "seaborn>=0.13.2",
 ]
 
 
@@ -50,7 +59,7 @@ notebooks = [
 cu128 = "https://download.pytorch.org/whl/cu128"
 
 [tool.setuptools]
-packages = ["make_model", "assess_birdnet", "whoot_model_training"]
+packages = ["make_model", "assess_birdnet", "whoot_model_training", "data_downloader"]
 
 [tool.uv.sources]
 pyha-analyzer = { git = "https://github.com/UCSD-E4E/pyha-analyzer-2.0.git", branch = "support_whoot" }
diff --git a/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py
new file mode 100644
index 0000000..d79fea3
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py
@@ -0,0 +1,62 @@
+"""Creates a Dataset from the Xeno-Canto Data Downloader tool.
+
+See data_downloader/xc.py
+"""
+
+import os
+from dataclasses import dataclass
+
+import numpy as np
+from datasets import (
+    load_dataset,
+    Dataset,
+    Audio,
+    DatasetDict,
+    ClassLabel,
+    Sequence,
+)
+from ..dataset import AudioDataset
+
+import json
+import pandas as pd
+
+
+def one_hot_encode(row: dict, classes: list):
+    """One-hot encodes a list of labels.
+
+    Args:
+        row (dict): row of data in a dataset containing a labels column
+        classes: a list of classes
+    """
+    one_hot = np.zeros(len(classes))
+    one_hot[row["labels"]] = 1
+    row["labels"] = np.array(one_hot, dtype=float)
+    return row
+
+def Jacuzzi_Olden_Extractor(
+    root_path
+):
+    audio_path = f"{root_path}/training/audio"
+    train_df = pd.read_csv(f"{root_path}/training/training_data_annotations.csv")
+    train_df["labels"] = train_df["labels"].str.split(",")
+    train_df["file_path"] = train_df["audio_subdir"].apply(
+        lambda folder: f"{audio_path}/{folder}/"
+    )
+    train_df["file"] = train_df["file"].apply(lambda path: path + ".wav")
+
+    test_df = pd.read_csv(f"{root_path}/test/test_data_annotations.csv")
+    test_df["labels"] = test_df["labels"].str.split(",")
+    test_df["file"] = test_df["file"].str.findall(
+        r"-0.\d+_([\w.]+).wav").apply(lambda x: x[0])
+    test_df["file_path"] = test_df["focal_class"].apply(
+        lambda folder: f"{audio_path}/{folder}/"
+    )
+    test_df["file"] = test_df["file"].apply(lambda path: path + ".wav")
+
+    return train_df, test_df
+
+    # TODO
+    # Convert to AudioDataset
+    # Convert Labels to right format
+    # Convert audio type
+    # Done
+
+
diff --git a/whoot_model_training/whoot_model_training/data_extractor/__init__.py b/whoot_model_training/whoot_model_training/data_extractor/__init__.py
index 5e0ffe7..7b5e158 100644
--- a/whoot_model_training/whoot_model_training/data_extractor/__init__.py
+++ b/whoot_model_training/whoot_model_training/data_extractor/__init__.py
@@ -8,5 +8,24 @@
     buowset_binary_extractor,
 )
 from .esc50_extractor import esc50_extractor
+from .Jacuzzi_Olden_extractor import Jacuzzi_Olden_Extractor
+from .xc_extractor import xc_extractor
 
-__all__ = ["buowset_extractor", "buowset_binary_extractor", "esc50_extractor"]
+__all__ = ["buowset_extractor", "buowset_binary_extractor", "esc50_extractor",
+           "Jacuzzi_Olden_Extractor", "xc_extractor"]
+
+def concat_dataset(datasetA, datasetB):
+    for split in datasetA.keys():
+        pass
+
+    # TODO: FIGURE OUT HOW TO SAFELY COMBINE TWO DATASETS
+
+    # labels
+    # this is tricky, you need to check class names for union, then
+    # apply annotations accordingly
+    # maybe use a dict to handle classes in both datasets
+
+    # Audio
+    # should be able to merge
+
+    # Metadata
+    # Consider dropping all non-required columns, will make merge easier
\ No newline at end of file
diff --git a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py
new file mode 100644
index 0000000..17b63e8
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py
@@ -0,0 +1,169 @@
+"""Creates a Dataset from the Xeno-Canto Data Downloader tool.
+
+See data_downloader/xc.py
+"""
+
+import os
+import shutil
+from pathlib import Path
+from dataclasses import dataclass
+from collections import Counter
+from pydub import AudioSegment
+
+
+import numpy as np
+from datasets import (
+    load_dataset,
+    Dataset,
+    Audio,
+    DatasetDict,
+    ClassLabel,
+    Sequence,
+)
+from ..dataset import AudioDataset
+
+import json
+import librosa
+
+def filter_by_count(ds, col="en", threshold=10):
+    count_by_species = Counter(ds[col])
+    return ds.filter(lambda row: count_by_species[row] > threshold, input_columns=[col])
+
+
+def filter_xc_data(row: dict):
+    """In personal experience, raw XC data is very messy.
+    Some files arrive corrupted, so this checks whether the file
+    can be loaded in the first place.
+    """
+
+    file_path = row["filepath"]
+    try:
+        # Heuristic: if we can load 3 seconds, the file is probably okay
+        # Prevents some files from taking forever
+        librosa.load(path=file_path, duration=3)
+        return True
+    except Exception as e:
+        print(e, file_path)
+        return False
+
+
+def one_hot_encode(row: dict, classes: list):
+    """One-hot encodes a list of labels.
+
+    Args:
+        row (dict): row of data in a dataset containing a labels column
+        classes: a list of classes
+    """
+    one_hot = np.zeros(len(classes))
+    one_hot[row["labels"]] = 1
+    row["labels"] = np.array(one_hot, dtype=float)
+    return row
+
+def convert_audio_to_flac(row, error_path="bad_files", col="audio"):
+    file_path = row[col]
+    flac_path = Path(file_path).parent / (Path(file_path).stem + ".flac")
+    if os.path.exists(flac_path):
+        row[col] = str(flac_path)
+        if os.path.exists(file_path):
+            os.remove(file_path)  # Remove original file, we don't need it
+        return row
+    try:
+        wav_audio = AudioSegment.from_file(file_path)
+        wav_audio.export(flac_path, format="flac")
+    except Exception as e:
+        if os.path.exists(file_path):
+            os.makedirs(error_path, exist_ok=True)
+            shutil.move(file_path, error_path)
+            # If we quit halfway through processing, make sure we get rid of the bad file
+            # if os.path.exists(flac_path):
+            #     os.remove(flac_path)
+        print("ERROR", "move to", os.path.join(error_path, Path(file_path).name), "ERR MSG:", e)
+        row[col] = str(os.path.join(error_path, Path(file_path).name))
+        return row
+    row[col] = str(flac_path)
+    return row
+
+@dataclass
+class XCParams():
+    """Parameters that describe the Xeno-Canto dataset.
+ + validation_fold (int): label for valid split + test_fold (int): label for valid split + sample_rate (int): sample rate of the data + filepath (string): name of column in csv for filepaths + """ + validation_fold = 4 + test_fold = 5 + sample_rate = 44_100 + +def xc_extractor( + XC_dataset_json_path, + parent_path, + params: XCParams = XCParams(), + bad_file_path="data/xc_bad_file" +): + + with open(XC_dataset_json_path, mode="r") as f: + xc_recordings_paged = json.load(f) + + xc_recordings = [] + for page in xc_recordings_paged: + xc_recordings.extend(page["recordings"]) + + dataset = Dataset.from_list(xc_recordings) + + dataset = dataset.add_column("labels", dataset["en"]) + dataset = dataset.class_encode_column("labels") + class_list = dataset.features["labels"].names + multilabel_class_label = Sequence(ClassLabel(names=class_list)) + dataset = dataset.map( + lambda row: one_hot_encode(row, class_list) + ).cast_column( + "labels", + multilabel_class_label + ) + + dataset = dataset.add_column( + "audio", [ + os.path.join(parent_path, file.replace("/", "_")) for file in dataset["file-name"] + ] + ) + + # Fix file paths + dataset = dataset.map(convert_audio_to_flac, fn_kwargs={"error_path": bad_file_path}, num_proc=16) + dataset = dataset.filter(lambda x: not bad_file_path in x["audio"], num_proc=16) + dataset = dataset.add_column("filepath", dataset["audio"]) + + + dataset = dataset.cast_column("audio", Audio(sampling_rate=params.sample_rate)) + + # TODO FIGURE OUT HOW TO DO SPLITS! + # # Create splits of the data + # test_ds = dataset.filter(lambda x: x["fold"] == params.test_fold) + # valid_ds = dataset.filter(lambda x: x["fold"] == params.validation_fold) + # train_ds = dataset.filter( + # lambda x: ( + # x["fold"] != params.test_fold + # and x["fold"] != params.validation_fold + # ) + # ) + + dataset = dataset.cast_column( + "en", ClassLabel(names=list(set(dataset["en"]))) + ) + + dataset = filter_by_count(dataset) + + train_test = dataset.train_test_split(0.2, stratify_by_column="en") + test_val = train_test["test"].train_test_split(0.2, stratify_by_column="en") + + dataset = AudioDataset( + DatasetDict({ + "train": train_test["train"], + "valid": test_val["train"], + "test": test_val["test"]}) + ) + + # dataset.save_to_disk(output_path) + + return dataset diff --git a/whoot_model_training/whoot_model_training/models/__init__.py b/whoot_model_training/whoot_model_training/models/__init__.py index 3c539ff..3c46187 100644 --- a/whoot_model_training/whoot_model_training/models/__init__.py +++ b/whoot_model_training/whoot_model_training/models/__init__.py @@ -6,6 +6,7 @@ from .timm_model import TimmModel, TimmInputs, TimmModelConfig from .model import Model, ModelInput, ModelOutput +from .few_shot_model import PerchEmbeddingInput, PerchFewShotModel, FewShotModelConfig __all__ = [ "TimmModel", @@ -13,5 +14,8 @@ "TimmModelConfig", "Model", "ModelInput", - "ModelOutput" + "ModelOutput", + "PerchEmbeddingInput", + "PerchFewShotModel", + "FewShotModelConfig" ] diff --git a/whoot_model_training/whoot_model_training/models/few_shot_model.py b/whoot_model_training/whoot_model_training/models/few_shot_model.py new file mode 100644 index 0000000..a93d2fc --- /dev/null +++ b/whoot_model_training/whoot_model_training/models/few_shot_model.py @@ -0,0 +1,112 @@ +"""Build a few_shot_learning classifier. + +Inspired by the work of +Jacuzzi, G., Olden, J.D., 2025. Few-shot transfer learning enables robust acoustic +monitoring of wildlife communities at the landscape scale. 
+Ecological Informatics 90, 103294. +doi.org/10.1016/j.ecoinf.2025.103294 + +These models convert thier input into an embedding from a large audio model and +do processing on top of that embedding +""" + +from .model import ModelInput, ModelOutput +from torch import nn, Tensor +from perch_hoplite.zoo import model_configs +from .model import Model, ModelInput, ModelOutput, has_required_inputs +from transformers import PretrainedConfig + +## Common Classes + +class EmbeddingModel(): + def embed(self): + raise NotImplementedError() + +class EmbeddingInput(ModelInput): + model = EmbeddingModel() + embedding_size = 0 + + def __init__(self, + labels, + waveform = None, + spectrogram = None): + super().__init__(labels, waveform, spectrogram) + + self["embedding"] = self.model.embed(waveform) + +## Unique Models + +class PerchEmbeddings(EmbeddingModel): + model = model_configs.load_model_by_name('perch_8') + def embed(self, waveforms): + # embeddings = [ + # self.model.embed(waveform).embeddings[0] + # for waveform in waveforms + # ] + return waveforms + +class PerchEmbeddingInput(EmbeddingInput): + model = PerchEmbeddings() + embedding_size = 1280 + + +class FewShotModelConfig(PretrainedConfig): + """Config for Timm Model Zoo Models!""" + def __init__( + self, + num_classes=200, + **kwargs + ): + """Creates Config. + + Args: + + """ + self.num_classes = num_classes + super().__init__(**kwargs) + +class PerchFewShotModel(Model, nn.Module): + def __init__( + self, + config: FewShotModelConfig + ): + """Init for TimmModel. + + kwargs: + timm_model (str): name of model backbone from timms to use, + Default: "resnet34" + pretrained (bool): use a pretrained model from timms, Default: True + in_chans (int): number of channels of audio: Default: 1 + num_classes (int): number of classes in the dataset: Default 6 + loss (any): custom loss function Default: BCEWithLogitsLoss + """ + super().__init__() + + self.input_format = PerchEmbeddingInput + self.output_format = ModelOutput + + self.config = config + assert config.num_classes > 0 + + # TODO BUILD MLP + self.linear = nn.Linear(self.input_format.embedding_size, config.num_classes) + + # TODO USE CUSTOM LOSS FOR FEW SHOW LEARNING + self.loss = nn.BCEWithLogitsLoss() + + @has_required_inputs() + def forward(self, x: PerchEmbeddingInput): + # Use perch to create embeddings + embeddings = Tensor(x.model.model.embed(x["waveform"].cpu()).embeddings).to(x["waveform"].device) + + logits = self.linear(embeddings).squeeze(1) + loss = self.loss(logits, x["labels"]) + + return ModelOutput( + logits=logits, + embeddings=embeddings, + loss=loss, + labels=x["labels"] + ) + + diff --git a/whoot_model_training/whoot_model_training/preprocessors/__init__.py b/whoot_model_training/whoot_model_training/preprocessors/__init__.py index 8efacfb..df579e5 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/__init__.py +++ b/whoot_model_training/whoot_model_training/preprocessors/__init__.py @@ -8,7 +8,8 @@ """ from .base_preprocessor import ( - MelModelInputPreprocessor + MelModelInputPreprocessor, + WaveformInputPreprocessor ) from .spectrogram_preprocessors import ( BuowMelSpectrogramPreprocessors @@ -16,5 +17,6 @@ __all__ = [ "MelModelInputPreprocessor", - "BuowMelSpectrogramPreprocessors" + "BuowMelSpectrogramPreprocessors", + "WaveformInputPreprocessor" ] diff --git a/whoot_model_training/whoot_model_training/preprocessors/augmentations.py b/whoot_model_training/whoot_model_training/preprocessors/augmentations.py new file mode 100644 index 0000000..2c457cd 
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/preprocessors/augmentations.py
@@ -0,0 +1,13 @@
+"""Contains various data augmentation techniques for bioacoustics.
+Notes: relies heavily on the audiomentations library
+
+Basically combine augmentations with ComposeAudioLabel
+
+For clarity, put augmentation imports here
+
+For Devs:
+To create a new augmentation, create an AudioLabelPreprocessor
+"""
+from pyha_analyzer.preprocessors.augmentations import ComposeAudioLabel, MixItUp, AudioLabelPreprocessor
+from audiomentations import Gain, PolarityInversion
+
diff --git a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py
index a7ad953..dc63ef4 100644
--- a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py
+++ b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py
@@ -23,6 +23,10 @@
 )
 from ..models.model import ModelInput
 
+from .waveform_preprocessors import (
+    WaveformPreprocessors
+)
+
 
 class SpectrogramModelInPreprocessors(PreProcessorBase):
     """Defines a preprocessor that after formatting the audio.
@@ -105,3 +109,61 @@ def __init__(
             spectrogram_params=spectrogram_params
         )
         super().__init__(spec_preprocessor, model_input)
+
+
+class WaveformInputPreprocessor(SpectrogramModelInPreprocessors):
+    """Formats raw waveforms into a ModelInput.
+
+    Uses WaveformPreprocessors rather than a spectrogram preprocessor.
+
+    This was created in part because the legacy implementation of
+    SpectrogramModelInputPreprocessors had these parameters and subclassed
+    BuowMelSpectrogramPreprocessors. This class replicates the
+    format of the old SpectrogramModelInputPreprocessors
+    class with the new functionality.
+    """
+    def __init__(
+        self,
+        model_input: ModelInput,
+        duration=5,
+        augments: Augmentations = Augmentations(),
+    ):
+        """Creates an online preprocessor for waveform-based models.
+
+        Formats input into a specific ModelInput format.
+
+        Args:
+            model_input (ModelInput): how the model likes its input data formatted
+            duration (int): length in seconds of input
+            augments (Augmentations): audio and spectrogram augmentations
+                to apply to each example
+        """
+        wav_preprocessor = WaveformPreprocessors(
+            duration=duration,
+            augments=augments,
+        )
+        super().__init__(wav_preprocessor, model_input)
+
+    def __call__(self, batch: dict) -> ModelInput:
+        """Processes a batch of AudioDataset rows.
+
+        For this specific preprocessor, it crops and augments the waveform,
+        then formats the data as a ModelInput.
+        """
+        batch = self.spec_preprocessor(batch)
+        return self.model_input(
+            labels=batch["labels"],
+            waveform=batch["audio"]
+        )
+
+
+
+
diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py
new file mode 100644
index 0000000..5651e10
--- /dev/null
+++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py
@@ -0,0 +1,130 @@
+"""Defines preprocessors for cropping and augmenting raw waveforms.
+
+Pulled from pyha_analyzer/preprocessors/spectogram_preprocessors.py
+"""
+from dataclasses import dataclass
+
+import librosa
+import numpy as np
+from torchvision import transforms
+
+from pyha_analyzer.preprocessors import PreProcessorBase
+
+
+# @dataclass
+# class WaveformParams:
+#     """Dataclass for spectrogram Parameters.
+
+#     n_fft: (int) number of fft bins
+#     hop_length (int) skip count
+#     power: (float) usually 2
+#     n_mels: (int) number of mel bins
+#     """
+#     n_fft: int = 2048
+#     hop_length: int = 256
+#     power: float = 2.0
+#     n_mels: int = 256
+
+
+@dataclass
+class Augmentations():
+    """Dataclass for the augmentations of the model.
+
+    audio (list[dict]): per item key name of augmentation,
+        value is the augmentation
+    spectrogram (list[dict]): same idea but augmentations
+        applied onto spectrograms
+    """
+    audio = None
+    spectrogram = None
+
+
+class WaveformPreprocessors(PreProcessorBase):
+    """Preprocessor that crops audio into fixed-duration waveform chunks.
+
+    Particularly for the buow dataset
+    """
+
+    def __init__(
+        self,
+        duration=5,
+        augments: Augmentations = Augmentations(),
+    ):
+        """Defines a WaveformPreprocessors.
+
+        Args:
+            duration (float): length of chunk of data to train on
+            augments (Augmentations): augmentations to apply to waveforms
+        """
+        self.duration = duration
+        self.augments = augments
+
+        # # Below parameter defaults from https://arxiv.org/pdf/2403.10380 pg 25
+        # self.n_fft = spectrogram_params.n_fft
+        # self.hop_length = spectrogram_params.hop_length
+        # self.power = spectrogram_params.power
+        # self.n_mels = spectrogram_params.n_mels
+        # self.spectrogram_params = spectrogram_params
+
+        super().__init__(name="WaveformPreprocessor")
+
+    def __call__(self, batch):
+        """Process a batch of data from an AudioDataset."""
+        new_audio = []
+        new_labels = []
+        for item_idx in range(len(batch["audio"])):
+            label = batch["labels"][item_idx]
+            try:
+                y, sr = librosa.load(path=batch["audio"][item_idx]["path"])
+            except Exception as e:
+                print(e)
+                print("File is likely corrupted, moving on")
+                continue
+
+            start = np.random.uniform(0, len(y)/sr - self.duration)
+
+            # Handle out of bound issues
+            end_sr = int(start * sr) + int(sr * self.duration)
+            if y.shape[-1] <= end_sr:
+                y = np.pad(y, end_sr - y.shape[-1])
+
+            # Audio Based Augmentations
+            if self.augments.audio is not None:
+                y, label = self.augments.audio(y, sr, label)
+
+            new_y = y[int(start * sr):end_sr]
+            if (new_y.shape[-1] < int(sr * self.duration)):
+                continue
+
+            new_audio.append(new_y)
+            new_labels.append(label)
+
+        batch["audio"] = new_audio
+        batch["labels"] = np.array(new_labels, dtype=np.float32)
+
+        return batch
+
+    def get_augmentations(self):
+        """Returns a list of augmentations.
+
+        Perhaps for logging purposes
+
+        Returns:
+            (list) all the augmentations
+        """
+        return self.augments
+
+    def __repr__(self):
+        """Use representation to describe the augmentations.
+ + Returns: + (str) all information about this preprocessor + """ + return ( + f"""{self.name} + Augmentations: {self.augments} + MelSpectrogram: {self.spectrogram_params} + """ + ) From add3efc3421d84ed3081635201e53415190f061c Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 10 Oct 2025 15:08:37 -0700 Subject: [PATCH 02/18] Add soundfile --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 25fd184..aab0492 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "python-dotenv>=1.1.1", "pyyaml>=6.0.2", "scikit-learn>=1.7.0", + "soundfile>=0.13.1", "tqdm>=4.67.1", ] From a27c56d51c5324b0829937349ba49ddb9036975a Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 10 Oct 2025 15:10:10 -0700 Subject: [PATCH 03/18] Add code for using XC api --- data_downloader/downloader_demo.ipynb | 238 ++++++++++++++++++++ data_downloader/get_more_species_data.ipynb | 184 +++++++++++++++ 2 files changed, 422 insertions(+) create mode 100644 data_downloader/downloader_demo.ipynb create mode 100644 data_downloader/get_more_species_data.ipynb diff --git a/data_downloader/downloader_demo.ipynb b/data_downloader/downloader_demo.ipynb new file mode 100644 index 0000000..f90fa04 --- /dev/null +++ b/data_downloader/downloader_demo.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3842a3a9", + "metadata": {}, + "outputs": [], + "source": [ + "from xc import XenoCantoDownloader\n", + "from dotenv import load_dotenv\n", + "import os\n", + "\n", + "# Load environment variables from the .env file\n", + "load_dotenv()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa77bd28", + "metadata": {}, + "outputs": [], + "source": [ + "xcd = XenoCantoDownloader(api_key=os.environ[\"XC_API_KEY\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ee1304b", + "metadata": {}, + "outputs": [], + "source": [ + "query = xcd.build_query()\n", + "res = xcd.get_page(query)\n", + "res" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fed0106", + "metadata": {}, + "outputs": [], + "source": [ + "data = xcd(query=\"box:32.485,-117.582,33.482,-115.228\")\n", + "d\n", + "#box:32.485,-117.582,33.482,-115.228" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a43fcef6", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import requests\n", + "with open(\"xc_meta.json\", mode=\"w\") as f:\n", + " json.dump(data, f, indent=4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0942a027", + "metadata": {}, + "outputs": [], + "source": [ + "req = requests.get(data[0][\"recordings\"][0][\"file\"])\n", + "req" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d62a7a6", + "metadata": {}, + "outputs": [], + "source": [ + "data[0][\"recordings\"][0][\"file\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcc63419", + "metadata": {}, + "outputs": [], + "source": [ + "req.content" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30af4509", + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "import os\n", + "from pathlib import Path\n", + "from multiprocessing.pool import ThreadPool\n", + "\n", + "# https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests\n", + "def download_file(url, local_filename, dry_run=False):\n", + " if os.path.exists(local_filename):\n", + " 
return local_filename\n", + "\n", + " try:\n", + " with requests.get(url, stream=True) as r:\n", + " with open(local_filename, 'wb') as f:\n", + " if not dry_run:\n", + " shutil.copyfileobj(r.raw, f)\n", + " else:\n", + " print(local_filename)\n", + "\n", + " return local_filename\n", + " except IOError as e:\n", + " print(e, flush=True)\n", + " return None\n", + "\n", + "def download_files(xcd, data, parent_folder=\"data/xeno-canto\", workers = 4):\n", + " def prep_download(args):\n", + " url = args[0]\n", + " file_path = args[1]\n", + " return download_file(url, file_path)\n", + "\n", + " os.makedirs(parent_folder, exist_ok=True)\n", + "\n", + " if \"recordings\" in data[0]:\n", + " data = xcd.concat_recording_data(data) \n", + " download_data = [\n", + " (recording[\"file\"], Path(parent_folder) / Path(recording[\"file-name\"]))\n", + " for recording in data\n", + " ]\n", + " pool = ThreadPool(workers)\n", + " results = pool.imap_unordered(prep_download, download_data) \n", + " pool.close()\n", + " return results\n", + "\n", + "download_files(xcd, data)" + ] + }, + { + "cell_type": "markdown", + "id": "ea02004c", + "metadata": {}, + "source": [ + "# Study" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5bf99a36", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "recordings = xcd.concat_recording_data(data)\n", + "df = pd.DataFrame(recordings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14a73c0f", + "metadata": {}, + "outputs": [], + "source": [ + "!uv add --optional notebooks seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4e26ec8", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "import seaborn as sns\n", + "# df[\"en\"].value_counts().hist()\n", + "\n", + "\n", + "sns.histplot(df[\"en\"].value_counts())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9d7dc37", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.ylabel(\"Number of Species\")\n", + "plt.xlabel(\"Number of Indivuals Per Species\")\n", + "plt.title(\"Do We Have a Few-shot Learning Problem for XC in Southern California?\")\n", + "df[\"en\"].value_counts().hist()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c04ef6f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "whoot", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/data_downloader/get_more_species_data.ipynb b/data_downloader/get_more_species_data.ipynb new file mode 100644 index 0000000..b9db482 --- /dev/null +++ b/data_downloader/get_more_species_data.ipynb @@ -0,0 +1,184 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "cac5b89f", + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "from xc import XenoCantoDownloader\n", + "from dotenv import load_dotenv\n", + "import os\n", + "\n", + "# Load environment variables from the .env file\n", + "load_dotenv()\n", + "\n", + "xcd = XenoCantoDownloader(api_key=os.environ[\"XC_API_KEY\"])\n", + "\n", + "\n", + "import librosa\n", + "# %%\n", + "import json\n", + 
"\n", + "with open(\"../data/san_diego_xc_aux/xc_meta_aux.json\", mode=\"r\") as f:\n", + " data = json.load(f)\n", + " # json.dump(data, f, indent=4)\n", + "\n", + "# %%\n", + "import requests\n", + "\n", + "# %%\n", + "import shutil\n", + "import os\n", + "from pathlib import Path\n", + "from multiprocessing.pool import ThreadPool\n", + "import tqdm\n", + "\n", + "# # https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests\n", + "# def download_file(url, local_filename, dry_run=False):\n", + "# if os.path.exists(local_filename):\n", + "# try:\n", + "# librosa.load(path=local_filename)\n", + "# return local_filename\n", + "# except Exception as e:\n", + "# pass\n", + " \n", + "# try:\n", + "# with requests.get(url, stream=True) as r:\n", + "# with open(local_filename, 'wb') as f:\n", + "# if not dry_run:\n", + "# shutil.copyfileobj(r.raw, f)\n", + "# else:\n", + "# print(local_filename)\n", + "\n", + "# return local_filename\n", + "# except Exception as e:\n", + "# print(e, flush=True)\n", + "# return None\n", + "\n", + "\n", + "def prep_download(args, dry_run=False):\n", + " url = args[0]\n", + " local_filename = args[1]\n", + " if os.path.exists(local_filename):\n", + " try:\n", + " librosa.load(path=local_filename)\n", + " return local_filename\n", + " except Exception as e:\n", + " print(local_filename, e, \"bad file, remake\")\n", + "\n", + " try:\n", + " with requests.get(url, stream=True) as r:\n", + " with open(local_filename, 'wb') as f:\n", + " if not dry_run:\n", + " shutil.copyfileobj(r.raw, f)\n", + " else:\n", + " print(local_filename)\n", + "\n", + " return local_filename\n", + " except Exception as e:\n", + " print(local_filename, e, flush=True)\n", + " return None\n", + "\n", + "def download_files(xcd, data, parent_folder=\"../data/san_diego_xc_aux/xeno-canto\", workers = 2):\n", + " \n", + "\n", + " os.makedirs(parent_folder, exist_ok=True)\n", + "\n", + " if \"recordings\" in data[0]:\n", + " data = xcd.concat_recording_data(data) \n", + " download_data = [\n", + " (recording[\"file\"], Path(parent_folder) / Path(recording[\"file-name\"].replace(\"/\", \"_\")))\n", + " for recording in data\n", + " ]\n", + "\n", + " with ThreadPool(processes=1024) as pool:\n", + " print(\"Main process: Submitting tasks...\")\n", + " \n", + " # Iterate over the results to wait for all tasks to complete.\n", + " # This loop will block until all tasks are finished.\n", + " for result in tqdm.tqdm(pool.imap_unordered(prep_download, download_data), total=len(download_data)):\n", + " if result is None:\n", + " print(\"ISSUE\")\n", + " \n", + " return results\n", + "\n", + "results = download_files(xcd, data)\n", + "results\n", + "\n", + "# %%\n", + "import pandas as pd\n", + "recordings = xcd.concat_recording_data(data)\n", + "df = pd.DataFrame(recordings)\n", + "\n", + "df.shape\n", + "\n", + "# %%\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce6bcf04", + "metadata": {}, + "outputs": [], + "source": [ + "132510 / 303" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aeee1789", + "metadata": {}, + "outputs": [], + "source": [ + "df[\"en\"].value_counts()[df[\"en\"].value_counts() < 1000].hist(bins=50)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0863f6a", + "metadata": {}, + "outputs": [], + "source": [ + "df[\"grp\"].value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59bad81f", + "metadata": {}, + "outputs": [], + 
"source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "whoot", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From ed11edc6ba7b0e86dc9490fb6d8766f3b0d2021a Mon Sep 17 00:00:00 2001 From: Sean Perry Date: Wed, 15 Oct 2025 11:49:16 -0700 Subject: [PATCH 04/18] Bird mae (#94) * Added some experimental inferance pipeline Needed this to test unseen burrowing owl dataset. Will help alot with getting data formatted correctly * Add testing code for loading in models for inferance * feat: add inferance pipeline * Block CSVs from commits * Move visualization packages to optional dependecy * Clean code (remove comments and remove run_spefific code from library) * Clean up timm_model * Add BirdMAE Model Training --- .gitignore | 1 + pyproject.toml | 5 + timm_check.ipynb | 3682 +++++++++++++++++ whoot_model_training/inferance.py | 122 + whoot_model_training/train.py | 36 +- .../data_extractor/__init__.py | 5 +- .../data_extractor/raw_audio_extractor.py | 247 ++ .../whoot_model_training/models/__init__.py | 4 + .../whoot_model_training/models/hf_models.py | 142 + .../whoot_model_training/models/timm_model.py | 4 +- .../preprocessors/__init__.py | 3 +- .../preprocessors/base_preprocessor.py | 1 + .../preprocessors/inferance_wrap.py | 7 + .../spectrogram_preprocessors.py | 4 +- .../preprocessors/waveform_preprocessors.py | 6 +- .../whoot_model_training/trainer.py | 17 +- 16 files changed, 4268 insertions(+), 18 deletions(-) create mode 100644 timm_check.ipynb create mode 100644 whoot_model_training/inferance.py create mode 100644 whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py create mode 100644 whoot_model_training/whoot_model_training/models/hf_models.py create mode 100644 whoot_model_training/whoot_model_training/preprocessors/inferance_wrap.py diff --git a/.gitignore b/.gitignore index b5bb093..d5ced77 100644 --- a/.gitignore +++ b/.gitignore @@ -226,3 +226,4 @@ settings.json # Block all configs besides the example config whoot_model_training/configs !whoot_model_training/configs/config.yml +*.csv diff --git a/pyproject.toml b/pyproject.toml index aab0492..3382876 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,10 +29,12 @@ version = {attr = "whoot.__version__"} [project.optional-dependencies] cpu = [ "torch>=2.7.0", + "torchaudio>=2.8.0", "torchvision>=0.22.0", ] cu128 = [ "torch>=2.7.0", + "torchaudio>=2.8.0", "torchvision>=0.22.0", ] model-training = [ @@ -54,6 +56,9 @@ notebooks = [ "matplotlib>=3.10.6", "seaborn>=0.13.2", ] +birdnet = [ + "birdnet>=0.1.7", +] [packages.index] diff --git a/timm_check.ipynb b/timm_check.ipynb new file mode 100644 index 0000000..d4b2c02 --- /dev/null +++ b/timm_check.ipynb @@ -0,0 +1,3682 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 63, + "id": "4877adf6", + "metadata": {}, + "outputs": [], + "source": [ + "import timm" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "cd16456e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['aimv2_1b_patch14_224',\n", + " 'aimv2_1b_patch14_336',\n", + " 'aimv2_1b_patch14_448',\n", + " 'aimv2_3b_patch14_224',\n", + " 'aimv2_3b_patch14_336',\n", + " 'aimv2_3b_patch14_448',\n", + " 'aimv2_huge_patch14_224',\n", 
+ " 'aimv2_huge_patch14_336',\n", + " 'aimv2_huge_patch14_448',\n", + " 'aimv2_large_patch14_224',\n", + " 'aimv2_large_patch14_336',\n", + " 'aimv2_large_patch14_448',\n", + " 'bat_resnext26ts',\n", + " 'beit3_base_patch16_224',\n", + " 'beit3_giant_patch14_224',\n", + " 'beit3_giant_patch14_336',\n", + " 'beit3_large_patch16_224',\n", + " 'beit_base_patch16_224',\n", + " 'beit_base_patch16_384',\n", + " 'beit_large_patch16_224',\n", + " 'beit_large_patch16_384',\n", + " 'beit_large_patch16_512',\n", + " 'beitv2_base_patch16_224',\n", + " 'beitv2_large_patch16_224',\n", + " 'botnet26t_256',\n", + " 'botnet50ts_256',\n", + " 'caformer_b36',\n", + " 'caformer_m36',\n", + " 'caformer_s18',\n", + " 'caformer_s36',\n", + " 'cait_m36_384',\n", + " 'cait_m48_448',\n", + " 'cait_s24_224',\n", + " 'cait_s24_384',\n", + " 'cait_s36_384',\n", + " 'cait_xs24_384',\n", + " 'cait_xxs24_224',\n", + " 'cait_xxs24_384',\n", + " 'cait_xxs36_224',\n", + " 'cait_xxs36_384',\n", + " 'coat_lite_medium',\n", + " 'coat_lite_medium_384',\n", + " 'coat_lite_mini',\n", + " 'coat_lite_small',\n", + " 'coat_lite_tiny',\n", + " 'coat_mini',\n", + " 'coat_small',\n", + " 'coat_tiny',\n", + " 'coatnet_0_224',\n", + " 'coatnet_0_rw_224',\n", + " 'coatnet_1_224',\n", + " 'coatnet_1_rw_224',\n", + " 'coatnet_2_224',\n", + " 'coatnet_2_rw_224',\n", + " 'coatnet_3_224',\n", + " 'coatnet_3_rw_224',\n", + " 'coatnet_4_224',\n", + " 'coatnet_5_224',\n", + " 'coatnet_bn_0_rw_224',\n", + " 'coatnet_nano_cc_224',\n", + " 'coatnet_nano_rw_224',\n", + " 'coatnet_pico_rw_224',\n", + " 'coatnet_rmlp_0_rw_224',\n", + " 'coatnet_rmlp_1_rw2_224',\n", + " 'coatnet_rmlp_1_rw_224',\n", + " 'coatnet_rmlp_2_rw_224',\n", + " 'coatnet_rmlp_2_rw_384',\n", + " 'coatnet_rmlp_3_rw_224',\n", + " 'coatnet_rmlp_nano_rw_224',\n", + " 'coatnext_nano_rw_224',\n", + " 'convformer_b36',\n", + " 'convformer_m36',\n", + " 'convformer_s18',\n", + " 'convformer_s36',\n", + " 'convit_base',\n", + " 'convit_small',\n", + " 'convit_tiny',\n", + " 'convmixer_768_32',\n", + " 'convmixer_1024_20_ks9_p14',\n", + " 'convmixer_1536_20',\n", + " 'convnext_atto',\n", + " 'convnext_atto_ols',\n", + " 'convnext_atto_rms',\n", + " 'convnext_base',\n", + " 'convnext_femto',\n", + " 'convnext_femto_ols',\n", + " 'convnext_large',\n", + " 'convnext_large_mlp',\n", + " 'convnext_nano',\n", + " 'convnext_nano_ols',\n", + " 'convnext_pico',\n", + " 'convnext_pico_ols',\n", + " 'convnext_small',\n", + " 'convnext_tiny',\n", + " 'convnext_tiny_hnf',\n", + " 'convnext_xlarge',\n", + " 'convnext_xxlarge',\n", + " 'convnext_zepto_rms',\n", + " 'convnext_zepto_rms_ols',\n", + " 'convnextv2_atto',\n", + " 'convnextv2_base',\n", + " 'convnextv2_femto',\n", + " 'convnextv2_huge',\n", + " 'convnextv2_large',\n", + " 'convnextv2_nano',\n", + " 'convnextv2_pico',\n", + " 'convnextv2_small',\n", + " 'convnextv2_tiny',\n", + " 'crossvit_9_240',\n", + " 'crossvit_9_dagger_240',\n", + " 'crossvit_15_240',\n", + " 'crossvit_15_dagger_240',\n", + " 'crossvit_15_dagger_408',\n", + " 'crossvit_18_240',\n", + " 'crossvit_18_dagger_240',\n", + " 'crossvit_18_dagger_408',\n", + " 'crossvit_base_240',\n", + " 'crossvit_small_240',\n", + " 'crossvit_tiny_240',\n", + " 'cs3darknet_focus_l',\n", + " 'cs3darknet_focus_m',\n", + " 'cs3darknet_focus_s',\n", + " 'cs3darknet_focus_x',\n", + " 'cs3darknet_l',\n", + " 'cs3darknet_m',\n", + " 'cs3darknet_s',\n", + " 'cs3darknet_x',\n", + " 'cs3edgenet_x',\n", + " 'cs3se_edgenet_x',\n", + " 'cs3sedarknet_l',\n", + " 'cs3sedarknet_x',\n", + " 'cs3sedarknet_xdw',\n", 
+ " 'cspdarknet53',\n", + " 'cspresnet50',\n", + " 'cspresnet50d',\n", + " 'cspresnet50w',\n", + " 'cspresnext50',\n", + " 'darknet17',\n", + " 'darknet21',\n", + " 'darknet53',\n", + " 'darknetaa53',\n", + " 'davit_base',\n", + " 'davit_base_fl',\n", + " 'davit_giant',\n", + " 'davit_huge',\n", + " 'davit_huge_fl',\n", + " 'davit_large',\n", + " 'davit_small',\n", + " 'davit_tiny',\n", + " 'deit3_base_patch16_224',\n", + " 'deit3_base_patch16_384',\n", + " 'deit3_huge_patch14_224',\n", + " 'deit3_large_patch16_224',\n", + " 'deit3_large_patch16_384',\n", + " 'deit3_medium_patch16_224',\n", + " 'deit3_small_patch16_224',\n", + " 'deit3_small_patch16_384',\n", + " 'deit_base_distilled_patch16_224',\n", + " 'deit_base_distilled_patch16_384',\n", + " 'deit_base_patch16_224',\n", + " 'deit_base_patch16_384',\n", + " 'deit_small_distilled_patch16_224',\n", + " 'deit_small_patch16_224',\n", + " 'deit_tiny_distilled_patch16_224',\n", + " 'deit_tiny_patch16_224',\n", + " 'densenet121',\n", + " 'densenet161',\n", + " 'densenet169',\n", + " 'densenet201',\n", + " 'densenet264d',\n", + " 'densenetblur121d',\n", + " 'dla34',\n", + " 'dla46_c',\n", + " 'dla46x_c',\n", + " 'dla60',\n", + " 'dla60_res2net',\n", + " 'dla60_res2next',\n", + " 'dla60x',\n", + " 'dla60x_c',\n", + " 'dla102',\n", + " 'dla102x',\n", + " 'dla102x2',\n", + " 'dla169',\n", + " 'dm_nfnet_f0',\n", + " 'dm_nfnet_f1',\n", + " 'dm_nfnet_f2',\n", + " 'dm_nfnet_f3',\n", + " 'dm_nfnet_f4',\n", + " 'dm_nfnet_f5',\n", + " 'dm_nfnet_f6',\n", + " 'dpn48b',\n", + " 'dpn68',\n", + " 'dpn68b',\n", + " 'dpn92',\n", + " 'dpn98',\n", + " 'dpn107',\n", + " 'dpn131',\n", + " 'eca_botnext26ts_256',\n", + " 'eca_halonext26ts',\n", + " 'eca_nfnet_l0',\n", + " 'eca_nfnet_l1',\n", + " 'eca_nfnet_l2',\n", + " 'eca_nfnet_l3',\n", + " 'eca_resnet33ts',\n", + " 'eca_resnext26ts',\n", + " 'eca_vovnet39b',\n", + " 'ecaresnet26t',\n", + " 'ecaresnet50d',\n", + " 'ecaresnet50d_pruned',\n", + " 'ecaresnet50t',\n", + " 'ecaresnet101d',\n", + " 'ecaresnet101d_pruned',\n", + " 'ecaresnet200d',\n", + " 'ecaresnet269d',\n", + " 'ecaresnetlight',\n", + " 'ecaresnext26t_32x4d',\n", + " 'ecaresnext50t_32x4d',\n", + " 'edgenext_base',\n", + " 'edgenext_small',\n", + " 'edgenext_small_rw',\n", + " 'edgenext_x_small',\n", + " 'edgenext_xx_small',\n", + " 'efficientformer_l1',\n", + " 'efficientformer_l3',\n", + " 'efficientformer_l7',\n", + " 'efficientformerv2_l',\n", + " 'efficientformerv2_s0',\n", + " 'efficientformerv2_s1',\n", + " 'efficientformerv2_s2',\n", + " 'efficientnet_b0',\n", + " 'efficientnet_b0_g8_gn',\n", + " 'efficientnet_b0_g16_evos',\n", + " 'efficientnet_b0_gn',\n", + " 'efficientnet_b1',\n", + " 'efficientnet_b1_pruned',\n", + " 'efficientnet_b2',\n", + " 'efficientnet_b2_pruned',\n", + " 'efficientnet_b3',\n", + " 'efficientnet_b3_g8_gn',\n", + " 'efficientnet_b3_gn',\n", + " 'efficientnet_b3_pruned',\n", + " 'efficientnet_b4',\n", + " 'efficientnet_b5',\n", + " 'efficientnet_b6',\n", + " 'efficientnet_b7',\n", + " 'efficientnet_b8',\n", + " 'efficientnet_blur_b0',\n", + " 'efficientnet_cc_b0_4e',\n", + " 'efficientnet_cc_b0_8e',\n", + " 'efficientnet_cc_b1_8e',\n", + " 'efficientnet_el',\n", + " 'efficientnet_el_pruned',\n", + " 'efficientnet_em',\n", + " 'efficientnet_es',\n", + " 'efficientnet_es_pruned',\n", + " 'efficientnet_h_b5',\n", + " 'efficientnet_l2',\n", + " 'efficientnet_lite0',\n", + " 'efficientnet_lite1',\n", + " 'efficientnet_lite2',\n", + " 'efficientnet_lite3',\n", + " 'efficientnet_lite4',\n", + " 'efficientnet_x_b3',\n", + " 
'efficientnet_x_b5',\n", + " 'efficientnetv2_l',\n", + " 'efficientnetv2_m',\n", + " 'efficientnetv2_rw_m',\n", + " 'efficientnetv2_rw_s',\n", + " 'efficientnetv2_rw_t',\n", + " 'efficientnetv2_s',\n", + " 'efficientnetv2_xl',\n", + " 'efficientvit_b0',\n", + " 'efficientvit_b1',\n", + " 'efficientvit_b2',\n", + " 'efficientvit_b3',\n", + " 'efficientvit_l1',\n", + " 'efficientvit_l2',\n", + " 'efficientvit_l3',\n", + " 'efficientvit_m0',\n", + " 'efficientvit_m1',\n", + " 'efficientvit_m2',\n", + " 'efficientvit_m3',\n", + " 'efficientvit_m4',\n", + " 'efficientvit_m5',\n", + " 'ese_vovnet19b_dw',\n", + " 'ese_vovnet19b_slim',\n", + " 'ese_vovnet19b_slim_dw',\n", + " 'ese_vovnet39b',\n", + " 'ese_vovnet39b_evos',\n", + " 'ese_vovnet57b',\n", + " 'ese_vovnet99b',\n", + " 'eva02_base_patch14_224',\n", + " 'eva02_base_patch14_448',\n", + " 'eva02_base_patch16_clip_224',\n", + " 'eva02_enormous_patch14_clip_224',\n", + " 'eva02_large_patch14_224',\n", + " 'eva02_large_patch14_448',\n", + " 'eva02_large_patch14_clip_224',\n", + " 'eva02_large_patch14_clip_336',\n", + " 'eva02_small_patch14_224',\n", + " 'eva02_small_patch14_336',\n", + " 'eva02_tiny_patch14_224',\n", + " 'eva02_tiny_patch14_336',\n", + " 'eva_giant_patch14_224',\n", + " 'eva_giant_patch14_336',\n", + " 'eva_giant_patch14_560',\n", + " 'eva_giant_patch14_clip_224',\n", + " 'eva_large_patch14_196',\n", + " 'eva_large_patch14_336',\n", + " 'fasternet_l',\n", + " 'fasternet_m',\n", + " 'fasternet_s',\n", + " 'fasternet_t0',\n", + " 'fasternet_t1',\n", + " 'fasternet_t2',\n", + " 'fastvit_ma36',\n", + " 'fastvit_mci0',\n", + " 'fastvit_mci1',\n", + " 'fastvit_mci2',\n", + " 'fastvit_s12',\n", + " 'fastvit_sa12',\n", + " 'fastvit_sa24',\n", + " 'fastvit_sa36',\n", + " 'fastvit_t8',\n", + " 'fastvit_t12',\n", + " 'fbnetc_100',\n", + " 'fbnetv3_b',\n", + " 'fbnetv3_d',\n", + " 'fbnetv3_g',\n", + " 'flexivit_base',\n", + " 'flexivit_large',\n", + " 'flexivit_small',\n", + " 'focalnet_base_lrf',\n", + " 'focalnet_base_srf',\n", + " 'focalnet_huge_fl3',\n", + " 'focalnet_huge_fl4',\n", + " 'focalnet_large_fl3',\n", + " 'focalnet_large_fl4',\n", + " 'focalnet_small_lrf',\n", + " 'focalnet_small_srf',\n", + " 'focalnet_tiny_lrf',\n", + " 'focalnet_tiny_srf',\n", + " 'focalnet_xlarge_fl3',\n", + " 'focalnet_xlarge_fl4',\n", + " 'gc_efficientnetv2_rw_t',\n", + " 'gcresnet33ts',\n", + " 'gcresnet50t',\n", + " 'gcresnext26ts',\n", + " 'gcresnext50ts',\n", + " 'gcvit_base',\n", + " 'gcvit_small',\n", + " 'gcvit_tiny',\n", + " 'gcvit_xtiny',\n", + " 'gcvit_xxtiny',\n", + " 'gernet_l',\n", + " 'gernet_m',\n", + " 'gernet_s',\n", + " 'ghostnet_050',\n", + " 'ghostnet_100',\n", + " 'ghostnet_130',\n", + " 'ghostnetv2_100',\n", + " 'ghostnetv2_130',\n", + " 'ghostnetv2_160',\n", + " 'ghostnetv3_050',\n", + " 'ghostnetv3_100',\n", + " 'ghostnetv3_130',\n", + " 'ghostnetv3_160',\n", + " 'gmixer_12_224',\n", + " 'gmixer_24_224',\n", + " 'gmlp_b16_224',\n", + " 'gmlp_s16_224',\n", + " 'gmlp_ti16_224',\n", + " 'halo2botnet50ts_256',\n", + " 'halonet26t',\n", + " 'halonet50ts',\n", + " 'halonet_h1',\n", + " 'haloregnetz_b',\n", + " 'hardcorenas_a',\n", + " 'hardcorenas_b',\n", + " 'hardcorenas_c',\n", + " 'hardcorenas_d',\n", + " 'hardcorenas_e',\n", + " 'hardcorenas_f',\n", + " 'hgnet_base',\n", + " 'hgnet_small',\n", + " 'hgnet_tiny',\n", + " 'hgnetv2_b0',\n", + " 'hgnetv2_b1',\n", + " 'hgnetv2_b2',\n", + " 'hgnetv2_b3',\n", + " 'hgnetv2_b4',\n", + " 'hgnetv2_b5',\n", + " 'hgnetv2_b6',\n", + " 'hiera_base_224',\n", + " 'hiera_base_abswin_256',\n", + " 
'hiera_base_plus_224',\n", + " 'hiera_huge_224',\n", + " 'hiera_large_224',\n", + " 'hiera_small_224',\n", + " 'hiera_small_abswin_256',\n", + " 'hiera_tiny_224',\n", + " 'hieradet_small',\n", + " 'hrnet_w18',\n", + " 'hrnet_w18_small',\n", + " 'hrnet_w18_small_v2',\n", + " 'hrnet_w18_ssld',\n", + " 'hrnet_w30',\n", + " 'hrnet_w32',\n", + " 'hrnet_w40',\n", + " 'hrnet_w44',\n", + " 'hrnet_w48',\n", + " 'hrnet_w48_ssld',\n", + " 'hrnet_w64',\n", + " 'inception_next_atto',\n", + " 'inception_next_base',\n", + " 'inception_next_small',\n", + " 'inception_next_tiny',\n", + " 'inception_resnet_v2',\n", + " 'inception_v3',\n", + " 'inception_v4',\n", + " 'lambda_resnet26rpt_256',\n", + " 'lambda_resnet26t',\n", + " 'lambda_resnet50ts',\n", + " 'lamhalobotnet50ts_256',\n", + " 'lcnet_035',\n", + " 'lcnet_050',\n", + " 'lcnet_075',\n", + " 'lcnet_100',\n", + " 'lcnet_150',\n", + " 'legacy_senet154',\n", + " 'legacy_seresnet18',\n", + " 'legacy_seresnet34',\n", + " 'legacy_seresnet50',\n", + " 'legacy_seresnet101',\n", + " 'legacy_seresnet152',\n", + " 'legacy_seresnext26_32x4d',\n", + " 'legacy_seresnext50_32x4d',\n", + " 'legacy_seresnext101_32x4d',\n", + " 'legacy_xception',\n", + " 'levit_128',\n", + " 'levit_128s',\n", + " 'levit_192',\n", + " 'levit_256',\n", + " 'levit_256d',\n", + " 'levit_384',\n", + " 'levit_384_s8',\n", + " 'levit_512',\n", + " 'levit_512_s8',\n", + " 'levit_512d',\n", + " 'levit_conv_128',\n", + " 'levit_conv_128s',\n", + " 'levit_conv_192',\n", + " 'levit_conv_256',\n", + " 'levit_conv_256d',\n", + " 'levit_conv_384',\n", + " 'levit_conv_384_s8',\n", + " 'levit_conv_512',\n", + " 'levit_conv_512_s8',\n", + " 'levit_conv_512d',\n", + " 'mambaout_base',\n", + " 'mambaout_base_plus_rw',\n", + " 'mambaout_base_short_rw',\n", + " 'mambaout_base_tall_rw',\n", + " 'mambaout_base_wide_rw',\n", + " 'mambaout_femto',\n", + " 'mambaout_kobe',\n", + " 'mambaout_small',\n", + " 'mambaout_small_rw',\n", + " 'mambaout_tiny',\n", + " 'maxvit_base_tf_224',\n", + " 'maxvit_base_tf_384',\n", + " 'maxvit_base_tf_512',\n", + " 'maxvit_large_tf_224',\n", + " 'maxvit_large_tf_384',\n", + " 'maxvit_large_tf_512',\n", + " 'maxvit_nano_rw_256',\n", + " 'maxvit_pico_rw_256',\n", + " 'maxvit_rmlp_base_rw_224',\n", + " 'maxvit_rmlp_base_rw_384',\n", + " 'maxvit_rmlp_nano_rw_256',\n", + " 'maxvit_rmlp_pico_rw_256',\n", + " 'maxvit_rmlp_small_rw_224',\n", + " 'maxvit_rmlp_small_rw_256',\n", + " 'maxvit_rmlp_tiny_rw_256',\n", + " 'maxvit_small_tf_224',\n", + " 'maxvit_small_tf_384',\n", + " 'maxvit_small_tf_512',\n", + " 'maxvit_tiny_pm_256',\n", + " 'maxvit_tiny_rw_224',\n", + " 'maxvit_tiny_rw_256',\n", + " 'maxvit_tiny_tf_224',\n", + " 'maxvit_tiny_tf_384',\n", + " 'maxvit_tiny_tf_512',\n", + " 'maxvit_xlarge_tf_224',\n", + " 'maxvit_xlarge_tf_384',\n", + " 'maxvit_xlarge_tf_512',\n", + " 'maxxvit_rmlp_nano_rw_256',\n", + " 'maxxvit_rmlp_small_rw_256',\n", + " 'maxxvit_rmlp_tiny_rw_256',\n", + " 'maxxvitv2_nano_rw_256',\n", + " 'maxxvitv2_rmlp_base_rw_224',\n", + " 'maxxvitv2_rmlp_base_rw_384',\n", + " 'maxxvitv2_rmlp_large_rw_224',\n", + " 'mixer_b16_224',\n", + " 'mixer_b32_224',\n", + " 'mixer_l16_224',\n", + " 'mixer_l32_224',\n", + " 'mixer_s16_224',\n", + " 'mixer_s32_224',\n", + " 'mixnet_l',\n", + " 'mixnet_m',\n", + " 'mixnet_s',\n", + " 'mixnet_xl',\n", + " 'mixnet_xxl',\n", + " 'mnasnet_050',\n", + " 'mnasnet_075',\n", + " 'mnasnet_100',\n", + " 'mnasnet_140',\n", + " 'mnasnet_small',\n", + " 'mobilenet_edgetpu_100',\n", + " 'mobilenet_edgetpu_v2_l',\n", + " 
'mobilenet_edgetpu_v2_m',\n", + " 'mobilenet_edgetpu_v2_s',\n", + " 'mobilenet_edgetpu_v2_xs',\n", + " 'mobilenetv1_100',\n", + " 'mobilenetv1_100h',\n", + " 'mobilenetv1_125',\n", + " 'mobilenetv2_035',\n", + " 'mobilenetv2_050',\n", + " 'mobilenetv2_075',\n", + " 'mobilenetv2_100',\n", + " 'mobilenetv2_110d',\n", + " 'mobilenetv2_120d',\n", + " 'mobilenetv2_140',\n", + " 'mobilenetv3_large_075',\n", + " 'mobilenetv3_large_100',\n", + " 'mobilenetv3_large_150d',\n", + " 'mobilenetv3_rw',\n", + " 'mobilenetv3_small_050',\n", + " 'mobilenetv3_small_075',\n", + " 'mobilenetv3_small_100',\n", + " 'mobilenetv4_conv_aa_large',\n", + " 'mobilenetv4_conv_aa_medium',\n", + " 'mobilenetv4_conv_blur_medium',\n", + " 'mobilenetv4_conv_large',\n", + " 'mobilenetv4_conv_medium',\n", + " 'mobilenetv4_conv_small',\n", + " 'mobilenetv4_conv_small_035',\n", + " 'mobilenetv4_conv_small_050',\n", + " 'mobilenetv4_hybrid_large',\n", + " 'mobilenetv4_hybrid_large_075',\n", + " 'mobilenetv4_hybrid_medium',\n", + " 'mobilenetv4_hybrid_medium_075',\n", + " 'mobilenetv5_300m',\n", + " 'mobilenetv5_300m_enc',\n", + " 'mobilenetv5_base',\n", + " 'mobileone_s0',\n", + " 'mobileone_s1',\n", + " 'mobileone_s2',\n", + " 'mobileone_s3',\n", + " 'mobileone_s4',\n", + " 'mobilevit_s',\n", + " 'mobilevit_xs',\n", + " 'mobilevit_xxs',\n", + " 'mobilevitv2_050',\n", + " 'mobilevitv2_075',\n", + " 'mobilevitv2_100',\n", + " 'mobilevitv2_125',\n", + " 'mobilevitv2_150',\n", + " 'mobilevitv2_175',\n", + " 'mobilevitv2_200',\n", + " 'mvitv2_base',\n", + " 'mvitv2_base_cls',\n", + " 'mvitv2_huge_cls',\n", + " 'mvitv2_large',\n", + " 'mvitv2_large_cls',\n", + " 'mvitv2_small',\n", + " 'mvitv2_small_cls',\n", + " 'mvitv2_tiny',\n", + " 'naflexvit_base_patch16_gap',\n", + " 'naflexvit_base_patch16_map',\n", + " 'naflexvit_base_patch16_par_gap',\n", + " 'naflexvit_base_patch16_parfac_gap',\n", + " 'naflexvit_base_patch16_siglip',\n", + " 'naflexvit_so150m2_patch16_reg1_gap',\n", + " 'naflexvit_so150m2_patch16_reg1_map',\n", + " 'naflexvit_so400m_patch16_siglip',\n", + " 'nasnetalarge',\n", + " 'nest_base',\n", + " 'nest_base_jx',\n", + " 'nest_small',\n", + " 'nest_small_jx',\n", + " 'nest_tiny',\n", + " 'nest_tiny_jx',\n", + " 'nextvit_base',\n", + " 'nextvit_large',\n", + " 'nextvit_small',\n", + " 'nf_ecaresnet26',\n", + " 'nf_ecaresnet50',\n", + " 'nf_ecaresnet101',\n", + " 'nf_regnet_b0',\n", + " 'nf_regnet_b1',\n", + " 'nf_regnet_b2',\n", + " 'nf_regnet_b3',\n", + " 'nf_regnet_b4',\n", + " 'nf_regnet_b5',\n", + " 'nf_resnet26',\n", + " 'nf_resnet50',\n", + " 'nf_resnet101',\n", + " 'nf_seresnet26',\n", + " 'nf_seresnet50',\n", + " 'nf_seresnet101',\n", + " 'nfnet_f0',\n", + " 'nfnet_f1',\n", + " 'nfnet_f2',\n", + " 'nfnet_f3',\n", + " 'nfnet_f4',\n", + " 'nfnet_f5',\n", + " 'nfnet_f6',\n", + " 'nfnet_f7',\n", + " 'nfnet_l0',\n", + " 'pit_b_224',\n", + " 'pit_b_distilled_224',\n", + " 'pit_s_224',\n", + " 'pit_s_distilled_224',\n", + " 'pit_ti_224',\n", + " 'pit_ti_distilled_224',\n", + " 'pit_xs_224',\n", + " 'pit_xs_distilled_224',\n", + " 'pnasnet5large',\n", + " 'poolformer_m36',\n", + " 'poolformer_m48',\n", + " 'poolformer_s12',\n", + " 'poolformer_s24',\n", + " 'poolformer_s36',\n", + " 'poolformerv2_m36',\n", + " 'poolformerv2_m48',\n", + " 'poolformerv2_s12',\n", + " 'poolformerv2_s24',\n", + " 'poolformerv2_s36',\n", + " 'pvt_v2_b0',\n", + " 'pvt_v2_b1',\n", + " 'pvt_v2_b2',\n", + " 'pvt_v2_b2_li',\n", + " 'pvt_v2_b3',\n", + " 'pvt_v2_b4',\n", + " 'pvt_v2_b5',\n", + " 'rdnet_base',\n", + " 'rdnet_large',\n", + " 
'rdnet_small',\n", + " 'rdnet_tiny',\n", + " 'regnetv_040',\n", + " 'regnetv_064',\n", + " 'regnetx_002',\n", + " 'regnetx_004',\n", + " 'regnetx_004_tv',\n", + " 'regnetx_006',\n", + " 'regnetx_008',\n", + " 'regnetx_016',\n", + " 'regnetx_032',\n", + " 'regnetx_040',\n", + " 'regnetx_064',\n", + " 'regnetx_080',\n", + " 'regnetx_120',\n", + " 'regnetx_160',\n", + " 'regnetx_320',\n", + " 'regnety_002',\n", + " 'regnety_004',\n", + " 'regnety_006',\n", + " 'regnety_008',\n", + " 'regnety_008_tv',\n", + " 'regnety_016',\n", + " 'regnety_032',\n", + " 'regnety_040',\n", + " 'regnety_040_sgn',\n", + " 'regnety_064',\n", + " 'regnety_080',\n", + " 'regnety_080_tv',\n", + " 'regnety_120',\n", + " 'regnety_160',\n", + " 'regnety_320',\n", + " 'regnety_640',\n", + " 'regnety_1280',\n", + " 'regnety_2560',\n", + " 'regnetz_005',\n", + " 'regnetz_040',\n", + " 'regnetz_040_h',\n", + " 'regnetz_b16',\n", + " 'regnetz_b16_evos',\n", + " 'regnetz_c16',\n", + " 'regnetz_c16_evos',\n", + " 'regnetz_d8',\n", + " 'regnetz_d8_evos',\n", + " 'regnetz_d32',\n", + " 'regnetz_e8',\n", + " 'repghostnet_050',\n", + " 'repghostnet_058',\n", + " 'repghostnet_080',\n", + " 'repghostnet_100',\n", + " 'repghostnet_111',\n", + " 'repghostnet_130',\n", + " 'repghostnet_150',\n", + " 'repghostnet_200',\n", + " 'repvgg_a0',\n", + " 'repvgg_a1',\n", + " 'repvgg_a2',\n", + " 'repvgg_b0',\n", + " 'repvgg_b1',\n", + " 'repvgg_b1g4',\n", + " 'repvgg_b2',\n", + " 'repvgg_b2g4',\n", + " 'repvgg_b3',\n", + " 'repvgg_b3g4',\n", + " 'repvgg_d2se',\n", + " 'repvit_m0_9',\n", + " 'repvit_m1',\n", + " 'repvit_m1_0',\n", + " 'repvit_m1_1',\n", + " 'repvit_m1_5',\n", + " 'repvit_m2',\n", + " 'repvit_m2_3',\n", + " 'repvit_m3',\n", + " 'res2net50_14w_8s',\n", + " 'res2net50_26w_4s',\n", + " 'res2net50_26w_6s',\n", + " 'res2net50_26w_8s',\n", + " 'res2net50_48w_2s',\n", + " 'res2net50d',\n", + " 'res2net101_26w_4s',\n", + " 'res2net101d',\n", + " 'res2next50',\n", + " 'resmlp_12_224',\n", + " 'resmlp_24_224',\n", + " 'resmlp_36_224',\n", + " 'resmlp_big_24_224',\n", + " 'resnest14d',\n", + " 'resnest26d',\n", + " 'resnest50d',\n", + " 'resnest50d_1s4x24d',\n", + " 'resnest50d_4s2x40d',\n", + " 'resnest101e',\n", + " 'resnest200e',\n", + " 'resnest269e',\n", + " 'resnet10t',\n", + " 'resnet14t',\n", + " 'resnet18',\n", + " 'resnet18d',\n", + " 'resnet26',\n", + " 'resnet26d',\n", + " 'resnet26t',\n", + " 'resnet32ts',\n", + " 'resnet33ts',\n", + " 'resnet34',\n", + " 'resnet34d',\n", + " 'resnet50',\n", + " 'resnet50_clip',\n", + " 'resnet50_clip_gap',\n", + " 'resnet50_gn',\n", + " 'resnet50_mlp',\n", + " 'resnet50c',\n", + " 'resnet50d',\n", + " 'resnet50s',\n", + " 'resnet50t',\n", + " 'resnet50x4_clip',\n", + " 'resnet50x4_clip_gap',\n", + " 'resnet50x16_clip',\n", + " 'resnet50x16_clip_gap',\n", + " 'resnet50x64_clip',\n", + " 'resnet50x64_clip_gap',\n", + " 'resnet51q',\n", + " 'resnet61q',\n", + " 'resnet101',\n", + " 'resnet101_clip',\n", + " 'resnet101_clip_gap',\n", + " 'resnet101c',\n", + " 'resnet101d',\n", + " 'resnet101s',\n", + " 'resnet152',\n", + " 'resnet152c',\n", + " 'resnet152d',\n", + " 'resnet152s',\n", + " 'resnet200',\n", + " 'resnet200d',\n", + " 'resnetaa34d',\n", + " 'resnetaa50',\n", + " 'resnetaa50d',\n", + " 'resnetaa101d',\n", + " 'resnetblur18',\n", + " 'resnetblur50',\n", + " 'resnetblur50d',\n", + " 'resnetblur101d',\n", + " 'resnetrs50',\n", + " 'resnetrs101',\n", + " 'resnetrs152',\n", + " 'resnetrs200',\n", + " 'resnetrs270',\n", + " 'resnetrs350',\n", + " 'resnetrs420',\n", + " 'resnetv2_18',\n", + 
" 'resnetv2_18d',\n", + " 'resnetv2_34',\n", + " 'resnetv2_34d',\n", + " 'resnetv2_50',\n", + " 'resnetv2_50d',\n", + " 'resnetv2_50d_evos',\n", + " 'resnetv2_50d_frn',\n", + " 'resnetv2_50d_gn',\n", + " 'resnetv2_50t',\n", + " 'resnetv2_50x1_bit',\n", + " 'resnetv2_50x3_bit',\n", + " 'resnetv2_101',\n", + " 'resnetv2_101d',\n", + " 'resnetv2_101x1_bit',\n", + " 'resnetv2_101x3_bit',\n", + " 'resnetv2_152',\n", + " 'resnetv2_152d',\n", + " 'resnetv2_152x2_bit',\n", + " 'resnetv2_152x4_bit',\n", + " 'resnext26ts',\n", + " 'resnext50_32x4d',\n", + " 'resnext50d_32x4d',\n", + " 'resnext101_32x4d',\n", + " 'resnext101_32x8d',\n", + " 'resnext101_32x16d',\n", + " 'resnext101_32x32d',\n", + " 'resnext101_64x4d',\n", + " 'rexnet_100',\n", + " 'rexnet_130',\n", + " 'rexnet_150',\n", + " 'rexnet_200',\n", + " 'rexnet_300',\n", + " 'rexnetr_100',\n", + " 'rexnetr_130',\n", + " 'rexnetr_150',\n", + " 'rexnetr_200',\n", + " 'rexnetr_300',\n", + " 'sam2_hiera_base_plus',\n", + " 'sam2_hiera_large',\n", + " 'sam2_hiera_small',\n", + " 'sam2_hiera_tiny',\n", + " 'samvit_base_patch16',\n", + " 'samvit_base_patch16_224',\n", + " 'samvit_huge_patch16',\n", + " 'samvit_large_patch16',\n", + " 'sebotnet33ts_256',\n", + " 'sedarknet21',\n", + " 'sehalonet33ts',\n", + " 'selecsls42',\n", + " 'selecsls42b',\n", + " 'selecsls60',\n", + " 'selecsls60b',\n", + " 'selecsls84',\n", + " 'semnasnet_050',\n", + " 'semnasnet_075',\n", + " 'semnasnet_100',\n", + " 'semnasnet_140',\n", + " 'senet154',\n", + " 'sequencer2d_l',\n", + " 'sequencer2d_m',\n", + " 'sequencer2d_s',\n", + " 'seresnet18',\n", + " 'seresnet33ts',\n", + " 'seresnet34',\n", + " 'seresnet50',\n", + " 'seresnet50t',\n", + " 'seresnet101',\n", + " 'seresnet152',\n", + " 'seresnet152d',\n", + " 'seresnet200d',\n", + " 'seresnet269d',\n", + " 'seresnetaa50d',\n", + " 'seresnext26d_32x4d',\n", + " 'seresnext26t_32x4d',\n", + " 'seresnext26ts',\n", + " 'seresnext50_32x4d',\n", + " 'seresnext101_32x4d',\n", + " 'seresnext101_32x8d',\n", + " 'seresnext101_64x4d',\n", + " 'seresnext101d_32x8d',\n", + " 'seresnextaa101d_32x8d',\n", + " 'seresnextaa201d_32x8d',\n", + " 'shvit_s1',\n", + " 'shvit_s2',\n", + " 'shvit_s3',\n", + " 'shvit_s4',\n", + " 'skresnet18',\n", + " 'skresnet34',\n", + " 'skresnet50',\n", + " 'skresnet50d',\n", + " 'skresnext50_32x4d',\n", + " 'spnasnet_100',\n", + " 'starnet_s1',\n", + " 'starnet_s2',\n", + " 'starnet_s3',\n", + " 'starnet_s4',\n", + " 'starnet_s050',\n", + " 'starnet_s100',\n", + " 'starnet_s150',\n", + " 'swiftformer_l1',\n", + " 'swiftformer_l3',\n", + " 'swiftformer_s',\n", + " 'swiftformer_xs',\n", + " 'swin_base_patch4_window7_224',\n", + " 'swin_base_patch4_window12_384',\n", + " 'swin_large_patch4_window7_224',\n", + " 'swin_large_patch4_window12_384',\n", + " 'swin_s3_base_224',\n", + " 'swin_s3_small_224',\n", + " 'swin_s3_tiny_224',\n", + " 'swin_small_patch4_window7_224',\n", + " 'swin_tiny_patch4_window7_224',\n", + " 'swinv2_base_window8_256',\n", + " 'swinv2_base_window12_192',\n", + " 'swinv2_base_window12to16_192to256',\n", + " 'swinv2_base_window12to24_192to384',\n", + " 'swinv2_base_window16_256',\n", + " 'swinv2_cr_base_224',\n", + " 'swinv2_cr_base_384',\n", + " 'swinv2_cr_base_ns_224',\n", + " 'swinv2_cr_giant_224',\n", + " 'swinv2_cr_giant_384',\n", + " 'swinv2_cr_huge_224',\n", + " 'swinv2_cr_huge_384',\n", + " 'swinv2_cr_large_224',\n", + " 'swinv2_cr_large_384',\n", + " 'swinv2_cr_small_224',\n", + " 'swinv2_cr_small_384',\n", + " 'swinv2_cr_small_ns_224',\n", + " 'swinv2_cr_small_ns_256',\n", + " 
'swinv2_cr_tiny_224',\n", + " 'swinv2_cr_tiny_384',\n", + " 'swinv2_cr_tiny_ns_224',\n", + " 'swinv2_large_window12_192',\n", + " 'swinv2_large_window12to16_192to256',\n", + " 'swinv2_large_window12to24_192to384',\n", + " 'swinv2_small_window8_256',\n", + " 'swinv2_small_window16_256',\n", + " 'swinv2_tiny_window8_256',\n", + " 'swinv2_tiny_window16_256',\n", + " 'test_byobnet',\n", + " 'test_convnext',\n", + " 'test_convnext2',\n", + " 'test_convnext3',\n", + " 'test_efficientnet',\n", + " 'test_efficientnet_evos',\n", + " 'test_efficientnet_gn',\n", + " 'test_efficientnet_ln',\n", + " 'test_mambaout',\n", + " 'test_nfnet',\n", + " 'test_resnet',\n", + " 'test_vit',\n", + " 'test_vit2',\n", + " 'test_vit3',\n", + " 'test_vit4',\n", + " 'tf_efficientnet_b0',\n", + " 'tf_efficientnet_b1',\n", + " 'tf_efficientnet_b2',\n", + " 'tf_efficientnet_b3',\n", + " 'tf_efficientnet_b4',\n", + " 'tf_efficientnet_b5',\n", + " 'tf_efficientnet_b6',\n", + " 'tf_efficientnet_b7',\n", + " 'tf_efficientnet_b8',\n", + " 'tf_efficientnet_cc_b0_4e',\n", + " 'tf_efficientnet_cc_b0_8e',\n", + " 'tf_efficientnet_cc_b1_8e',\n", + " 'tf_efficientnet_el',\n", + " 'tf_efficientnet_em',\n", + " 'tf_efficientnet_es',\n", + " 'tf_efficientnet_l2',\n", + " 'tf_efficientnet_lite0',\n", + " 'tf_efficientnet_lite1',\n", + " 'tf_efficientnet_lite2',\n", + " 'tf_efficientnet_lite3',\n", + " 'tf_efficientnet_lite4',\n", + " 'tf_efficientnetv2_b0',\n", + " 'tf_efficientnetv2_b1',\n", + " 'tf_efficientnetv2_b2',\n", + " 'tf_efficientnetv2_b3',\n", + " 'tf_efficientnetv2_l',\n", + " 'tf_efficientnetv2_m',\n", + " 'tf_efficientnetv2_s',\n", + " 'tf_efficientnetv2_xl',\n", + " 'tf_mixnet_l',\n", + " 'tf_mixnet_m',\n", + " 'tf_mixnet_s',\n", + " 'tf_mobilenetv3_large_075',\n", + " 'tf_mobilenetv3_large_100',\n", + " 'tf_mobilenetv3_large_minimal_100',\n", + " 'tf_mobilenetv3_small_075',\n", + " 'tf_mobilenetv3_small_100',\n", + " 'tf_mobilenetv3_small_minimal_100',\n", + " 'tiny_vit_5m_224',\n", + " 'tiny_vit_11m_224',\n", + " 'tiny_vit_21m_224',\n", + " 'tiny_vit_21m_384',\n", + " 'tiny_vit_21m_512',\n", + " 'tinynet_a',\n", + " 'tinynet_b',\n", + " 'tinynet_c',\n", + " ...]" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "timm.list_models()" + ] + }, + { + "cell_type": "markdown", + "id": "e965b948", + "metadata": {}, + "source": [ + "# Test Inferance" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "d72de444", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "652d208732db40d08a4f1ca5565ed0eb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/31537 [00:00) tensor(5.7054, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-10.6463, -11.6919, -12.0322, -9.4530, -15.9026, 8.7405],\n", + " [-18.7760, -18.9497, -19.5558, -13.0709, -24.5872, 13.3310],\n", + " [-17.6199, -15.1691, -16.6451, -13.0334, -22.5385, 12.1574],\n", + " [ -9.9132, -10.7750, -12.4069, -8.6010, -15.7670, 8.2736],\n", + " [-15.9893, -15.6640, -14.2749, -10.5522, -20.2389, 11.0540]],\n", + " device='cuda:0', grad_fn=) tensor(4.2167, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.7153, -19.8876, -22.0921, -15.6335, -25.7065, 17.0069],\n", + " [-15.3676, -14.2467, -15.2342, -9.8990, -18.9938, 9.4647],\n", + " [-17.9792, -15.3863, -18.0587, -13.0819, -21.3875, 14.6569],\n", + " [-18.6688, -14.4102, -19.5750, -11.4955, -21.4084, 12.8969],\n", + " [-20.7088, -16.0413, 
-19.5438, -11.7129, -20.8527, 13.2308]],\n", + " device='cuda:0', grad_fn=) tensor(5.4232, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.2884, -13.2256, -19.0231, -10.5247, -19.4325, 12.5690],\n", + " [-16.4247, -14.8005, -16.2628, -7.1007, -17.1003, 9.8032],\n", + " [-16.3833, -13.2512, -17.6678, -11.7785, -19.4521, 11.2706],\n", + " [-17.9852, -16.8564, -19.0888, -14.4766, -21.9142, 14.0923],\n", + " [-22.7285, -20.0488, -23.6423, -16.3384, -26.4886, 16.9303]],\n", + " device='cuda:0', grad_fn=) tensor(5.2159, device='cuda:0',\n", + " grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_817618/3868621331.py:10: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments.\n", + " results.append(np.array(torch.argmax(out[\"logits\"], dim=1).cpu()))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[-25.6627, -19.1187, -26.0079, -17.6518, -27.8401, 17.9818],\n", + " [-20.4639, -14.6373, -21.0470, -13.1323, -21.8942, 14.0545],\n", + " [-17.6472, -12.9593, -16.8312, -13.0439, -17.8351, 12.1448],\n", + " [-24.5938, -19.4026, -23.6713, -15.1207, -27.0474, 16.7659],\n", + " [-23.4175, -18.7242, -20.0196, -14.1604, -26.2955, 15.2203]],\n", + " device='cuda:0', grad_fn=) tensor(6.2651, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.5956, -18.4819, -21.1396, -15.2815, -24.7799, 15.8402],\n", + " [-22.8801, -18.1315, -20.3435, -14.7207, -23.3664, 14.7429],\n", + " [-25.1979, -21.6955, -26.2941, -18.0929, -27.1179, 17.7363],\n", + " [-27.9811, -22.1019, -29.0579, -17.6141, -32.1800, 18.6680],\n", + " [-22.1971, -18.7451, -21.6321, -15.6149, -19.7802, 13.7537]],\n", + " device='cuda:0', grad_fn=) tensor(6.7198, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-21.9624, -18.6315, -23.3189, -14.7934, -25.5468, 14.9872],\n", + " [-21.1179, -18.9034, -21.8123, -12.0546, -24.4166, 13.5688],\n", + " [-24.9073, -20.3081, -24.3183, -15.9231, -26.2028, 17.9872],\n", + " [-12.4258, -10.7020, -13.1045, -7.9197, -5.1201, 2.4479],\n", + " [-18.6397, -16.2914, -19.6577, -11.1818, -22.3971, 12.8716]],\n", + " device='cuda:0', grad_fn=) tensor(5.3668, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.2335, -11.0925, -18.6173, -10.8272, -20.2059, 11.4161],\n", + " [-12.7528, -15.1303, -13.3891, -10.4232, -18.1346, 11.0176],\n", + " [-23.3070, -20.2974, -24.0004, -15.1481, -28.3255, 17.0426],\n", + " [-19.4940, -16.0547, -19.0815, -11.1551, -21.8457, 13.9961],\n", + " [-16.6287, -13.2187, -16.9061, -10.3401, -20.6072, 11.5725]],\n", + " device='cuda:0', grad_fn=) tensor(5.1154, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.0606, -15.3820, -16.0477, -10.4870, -19.0424, 10.9241],\n", + " [-18.9815, -16.4758, -21.5607, -12.1618, -24.4382, 13.6512],\n", + " [-13.5848, -10.5080, -15.7932, -9.5371, -18.0188, 9.1288],\n", + " [-18.6177, -14.1626, -21.8735, -10.3117, -22.8466, 12.2248],\n", + " [-17.6302, -15.2599, -17.4120, -11.6195, -18.0946, 12.4481]],\n", + " device='cuda:0', grad_fn=) tensor(4.7751, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-27.1740, -22.7134, -27.7187, -18.0239, -26.7234, 18.3074],\n", + " [-23.9139, -19.1329, -23.9814, -16.8671, -25.8635, 17.0360],\n", + " [-16.0124, -13.2089, -15.8934, -13.4931, -17.3414, 11.5393],\n", + " [-19.4569, -11.7860, -18.7430, -13.7046, -20.8822, 13.4296],\n", + " [-11.7545, -11.9697, -16.6999, -11.4972, -15.6422, 9.9268]],\n", + " 
device='cuda:0', grad_fn=) tensor(5.6184, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.0108, -16.9760, -21.6040, -14.2101, -23.7109, 15.8287],\n", + " [-19.1937, -13.8363, -18.0534, -13.2754, -19.3030, 12.3543],\n", + " [-14.8210, -11.6121, -14.3339, -11.0961, -18.6467, 10.4032],\n", + " [ -8.4009, -7.1207, -7.5226, -4.5164, -10.0468, 3.8913],\n", + " [-18.1036, -15.5290, -19.4369, -14.2597, -17.3314, 12.5013]],\n", + " device='cuda:0', grad_fn=) tensor(4.5181, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.3597, -15.4303, -17.1589, -12.9287, -13.9939, 11.7821],\n", + " [ -9.2004, -10.8517, -14.1135, -8.7450, -14.4098, 5.5793],\n", + " [-15.1323, -14.4030, -16.4938, -10.2227, -17.9987, 11.3838],\n", + " [-20.6831, -19.5578, -20.6505, -14.3506, -22.0386, 14.3788],\n", + " [-17.6160, -14.2330, -19.2402, -10.0904, -19.4387, 10.8952]],\n", + " device='cuda:0', grad_fn=) tensor(4.4672, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-23.3211, -18.3509, -23.9018, -14.1461, -26.0910, 15.0205],\n", + " [-13.2699, -12.6511, -13.3758, -9.7168, -14.7329, 9.0603],\n", + " [-22.0040, -19.1629, -21.3595, -12.0876, -23.2405, 13.2916],\n", + " [-24.0177, -20.6618, -22.4069, -12.7104, -25.9310, 13.9082],\n", + " [-19.7511, -17.3349, -18.1201, -12.2846, -19.2682, 12.4322]],\n", + " device='cuda:0', grad_fn=) tensor(5.5359, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[ -0.5568, -5.7489, -6.0215, -6.2858, -8.5979, -0.1263],\n", + " [-10.3621, -9.7913, -12.4371, 6.7582, -15.3110, -4.8748],\n", + " [-15.9598, -13.0867, -17.8066, -9.5276, -21.2169, 10.0091],\n", + " [-10.8876, -9.3344, -9.6711, -3.6716, -13.1439, 4.2599],\n", + " [ -3.8170, -4.2192, -8.4727, -3.1024, -8.9307, -0.6347]],\n", + " device='cuda:0', grad_fn=) tensor(2.1419, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.6291, -14.5448, -17.2136, -9.1068, -18.4519, 10.2821],\n", + " [-17.0967, -16.7530, -17.0745, -10.1121, -20.0521, 11.6232],\n", + " [-18.6936, -13.5111, -19.3041, -9.9835, -21.2757, 11.5052],\n", + " [-21.4557, -13.5039, -23.8494, -9.9800, -23.7989, 8.7664],\n", + " [-15.7949, -11.9915, -15.0945, -9.9108, -18.1268, 10.7580]],\n", + " device='cuda:0', grad_fn=) tensor(4.7869, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-21.5734, -17.8346, -23.4718, -12.6840, -25.0353, 14.2529],\n", + " [-15.4050, -13.2808, -17.2387, -10.4663, -18.8073, 10.8295],\n", + " [-20.1910, -17.1647, -22.8864, -12.8411, -24.1407, 14.3908],\n", + " [ -9.7707, -11.4383, -10.8691, -2.9999, -12.1758, 5.1991],\n", + " [-15.2149, -12.4283, -15.3228, -9.3511, -18.1027, 11.5676]],\n", + " device='cuda:0', grad_fn=) tensor(4.6150, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-14.4720, -11.7171, -16.2725, -10.7464, -18.5543, 10.5050],\n", + " [-18.7774, -14.9434, -22.5504, -12.5637, -22.8571, 14.0289],\n", + " [-16.2015, -11.6725, -16.4397, -10.3708, -18.3610, 10.4412],\n", + " [-18.6445, -14.1896, -20.9501, -13.7864, -23.1979, 14.1335],\n", + " [-14.3730, -10.4029, -16.3387, -9.2232, -18.2504, 10.2325]],\n", + " device='cuda:0', grad_fn=) tensor(4.7270, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[ -8.7186, -3.8538, -8.9463, -2.5103, -8.8867, 0.5764],\n", + " [-10.2975, -8.5321, -11.2477, -3.7862, -12.9197, 3.4381],\n", + " [-18.2361, -13.5092, -18.9623, -11.2675, -19.4533, 11.1002],\n", + " [-14.5296, -16.7200, -17.5984, 9.4304, -20.5124, -7.2371],\n", + " [-13.5946, -7.9236, -14.6383, -7.1874, -14.5220, 7.2951]],\n", + " device='cuda:0', grad_fn=) tensor(3.2606, device='cuda:0',\n", + " grad_fn=)\n", + 
"tensor([[-13.2059, -13.2659, -13.7858, -8.5993, -15.4528, 9.3710],\n", + " [-23.1780, -21.0078, -23.8441, -16.4174, -24.6529, 15.7785],\n", + " [-21.8683, -19.6224, -21.2988, -14.3918, -24.7637, 14.9617],\n", + " [-25.4250, -21.2676, -26.0800, -15.8489, -28.4327, 17.8570],\n", + " [-12.0469, -10.4457, -15.7378, 11.8434, -19.0611, -11.2000]],\n", + " device='cuda:0', grad_fn=) tensor(5.5179, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[ -2.0269, -6.1973, -8.4963, -1.5625, -8.7724, -0.9863],\n", + " [-16.5860, -11.7183, -16.2914, -9.3547, -15.8960, 9.5831],\n", + " [ -8.7699, -9.6847, -9.3823, 1.5408, -9.4744, -1.2493],\n", + " [-10.8589, -6.7981, -7.5986, -7.6300, -11.8273, 4.4653],\n", + " [-22.4308, -17.9468, -22.7414, -12.8748, -24.2002, 13.6484]],\n", + " device='cuda:0', grad_fn=) tensor(3.0335, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-13.6758, -11.6944, -12.8267, -4.8349, -14.9143, 7.1880],\n", + " [-17.4477, -12.3160, -18.5011, -5.0701, -16.5591, 5.0563],\n", + " [-14.9894, -13.0718, -17.3287, -11.7971, -16.9851, 11.3770],\n", + " [-11.8078, -10.0448, -12.9544, -7.4926, -14.7328, 8.8941],\n", + " [-15.3958, -13.9130, -16.3935, -8.6360, -17.4381, 10.5895]],\n", + " device='cuda:0', grad_fn=) tensor(3.8815, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.5499, -15.5046, -19.7246, -8.8880, -22.9716, 11.3232],\n", + " [-15.8465, -13.5015, -15.6334, -11.6366, -15.8002, 11.3872],\n", + " [-23.9741, -20.7196, -24.7571, -14.9094, -27.2466, 17.6403],\n", + " [-20.3888, -18.1561, -21.1640, -14.1285, -23.7475, 15.0971],\n", + " [-12.1707, -9.5369, -11.9994, -8.1814, -12.0903, 9.1787]],\n", + " device='cuda:0', grad_fn=) tensor(5.1519, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.0370, -15.6951, -19.2186, -11.8599, -22.1417, 13.2959],\n", + " [-21.9732, -13.7963, -21.3215, -12.2286, -24.0093, 13.1995],\n", + " [-14.3827, -12.8521, -16.1328, -9.3755, -19.2373, 9.9412],\n", + " [ -8.0022, -8.7992, -8.0420, -1.2154, -10.9254, 1.7340],\n", + " [ -6.9496, -3.6115, -10.8233, 5.7642, -10.7271, -7.3332]],\n", + " device='cuda:0', grad_fn=) tensor(3.8245, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[ -6.7554, -7.0597, -0.5188, -1.4020, -8.4866, -4.5663],\n", + " [-10.3654, -9.4010, -9.8211, 8.1145, -13.0029, -7.4530],\n", + " [-18.1601, -16.8406, -18.7227, -11.4220, -22.2714, 14.0549],\n", + " [-16.6357, -12.7993, -17.6217, -10.5040, -18.1779, 12.7005],\n", + " [-17.6639, -14.5238, -16.8750, -11.1519, -19.1025, 11.6848]],\n", + " device='cuda:0', grad_fn=) tensor(3.8945, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.2514, -14.1619, -21.5101, -13.2114, -21.2256, 12.6599],\n", + " [-18.9201, -17.4807, -18.1632, -12.4249, -19.3518, 14.1659],\n", + " [-13.6821, -8.9911, -15.3710, -10.3057, -17.3289, 7.6996],\n", + " [-13.7279, -15.6886, -15.3964, -9.7234, -17.6351, 11.8476],\n", + " [-16.0148, -12.6544, -15.3582, -10.5735, -14.4972, 10.5639]],\n", + " device='cuda:0', grad_fn=) tensor(4.6511, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-12.9476, -10.9807, -13.1031, -6.1276, -13.7551, 8.2257],\n", + " [-11.6462, -10.5670, -11.0535, -6.0647, -9.0851, 6.7473],\n", + " [-19.4379, -16.9751, -20.5273, -11.1873, -24.7015, 13.0367],\n", + " [-20.7668, -17.2671, -22.8714, -13.8839, -26.0097, 15.3822],\n", + " [-19.0570, -14.3360, -20.7219, -11.2066, -23.1379, 11.5094]],\n", + " device='cuda:0', grad_fn=) tensor(4.6254, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.3763, -16.9047, -18.2432, -13.8114, -22.4242, 13.9286],\n", + " [-14.7076, -12.3572, 
-18.3071, -10.6855, -19.9058, 11.4350],\n", + " [-18.1030, -16.7649, -18.3708, -12.5672, -22.5378, 14.4404],\n", + " [-15.3892, -12.1223, -14.2400, -8.7707, -16.4338, 9.5381],\n", + " [-14.6568, -10.0112, -14.0750, -8.8746, -15.2421, 9.4534]],\n", + " device='cuda:0', grad_fn=) tensor(4.6676, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.4078, -17.8550, -20.6478, -14.4612, -23.0387, 15.9179],\n", + " [-13.3120, -9.8104, -12.5739, -10.1389, -16.2885, 10.0250],\n", + " [-20.0272, -17.4726, -22.0632, -11.8177, -23.3084, 13.4900],\n", + " [-15.0457, -12.4919, -15.7207, -8.3389, -17.5007, 9.6556],\n", + " [-15.3929, -12.7397, -15.3918, -9.7167, -17.0744, 9.1469]],\n", + " device='cuda:0', grad_fn=) tensor(4.7141, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-21.4495, -14.4396, -23.4651, -12.4692, -25.5561, 13.2279],\n", + " [-15.3237, -14.3198, -16.5649, -10.8595, -20.9813, 10.7648],\n", + " [-12.9336, -9.1191, -13.1205, -9.8584, -15.9879, 9.0995],\n", + " [-25.2751, -21.7327, -25.3630, -15.0115, -29.2816, 17.0471],\n", + " [-25.0514, -22.7729, -25.8643, -16.6358, -29.2843, 18.1631]],\n", + " device='cuda:0', grad_fn=) tensor(5.6112, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.0194, -12.0820, -16.9358, -8.5514, -19.4512, 10.4406],\n", + " [-18.7939, -13.3192, -18.1140, -12.2072, -21.1129, 12.7919],\n", + " [-12.5711, -9.5236, -12.8724, -7.7499, -12.7376, 7.8732],\n", + " [-20.3983, -15.0184, -22.3984, -14.6905, -23.4612, 14.3706],\n", + " [-17.7230, -9.5240, -19.2937, -11.6319, -18.9726, 11.3751]],\n", + " device='cuda:0', grad_fn=) tensor(4.7453, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.1723, -16.2675, -22.2052, -13.0983, -24.3365, 14.1456],\n", + " [-11.0697, -10.0703, -13.1738, -4.8186, -14.3974, 6.8464],\n", + " [-11.2727, 7.5081, -13.0219, -8.6593, -12.1459, -6.8008],\n", + " [-21.6858, -17.4800, -23.2178, -11.7569, -23.6457, 13.6635],\n", + " [-17.5239, -15.7766, -19.1453, -12.8879, -20.5150, 14.7840]],\n", + " device='cuda:0', grad_fn=) tensor(4.5894, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-11.8716, -11.6161, -14.2173, -5.5436, -13.3707, 7.0785],\n", + " [-13.5630, -10.6320, -13.3879, -9.0720, -13.2748, 9.1122],\n", + " [-18.8632, -15.0208, -18.3854, -13.8823, -20.2418, 13.6420],\n", + " [-21.1739, -16.9934, -23.7775, -14.9724, -24.1337, 15.7189],\n", + " [-17.9503, -13.8916, -18.0888, -12.2282, -19.3619, 13.0755]],\n", + " device='cuda:0', grad_fn=) tensor(4.7351, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-10.3899, -4.0264, -12.1425, -8.2236, -13.6513, 5.4770],\n", + " [-13.4120, -8.9059, -15.7140, -9.5008, -16.5135, 9.1395],\n", + " [-16.1756, -14.4034, -15.9220, -10.1733, -18.8613, 11.8756],\n", + " [-19.3965, -18.8583, -20.2657, -12.2348, -24.9079, 14.2655],\n", + " [-13.7035, -11.6843, -13.8681, -8.0009, -15.3420, 8.7375]],\n", + " device='cuda:0', grad_fn=) tensor(4.0865, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.7125, -11.5141, -20.1644, -11.2324, -19.4652, 10.3156],\n", + " [-13.7782, -10.2003, -15.1902, -8.1787, -15.0809, 10.4336],\n", + " [-14.4018, -13.0009, -14.9035, -9.9122, -16.2714, 11.1745],\n", + " [ -9.4523, -12.1932, -13.5959, -8.6677, -15.6577, 8.3429],\n", + " [-12.4923, -11.9569, -16.6663, -11.0034, -18.0326, 11.2764]],\n", + " device='cuda:0', grad_fn=) tensor(3.9127, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.5043, -17.3827, -22.0365, -13.3728, -24.3745, 15.4590],\n", + " [-16.2307, -12.9686, -18.9715, -12.1647, -19.6928, 12.1505],\n", + " [-16.7359, -15.6334, -19.2490, -11.3302, 
-21.1625, 13.3982],\n", + " [-19.3899, -18.3584, -21.5848, -14.8992, -22.1001, 15.8507],\n", + " [-19.3322, -17.2064, -19.7197, -12.2162, -22.9659, 12.8671]],\n", + " device='cuda:0', grad_fn=) tensor(5.3973, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-14.1309, -11.2604, -16.7232, -10.6595, -18.3610, 11.5822],\n", + " [-21.1509, -17.8577, -22.2355, -14.8425, -24.8438, 15.6943],\n", + " [-21.0714, -14.6085, -22.0803, -13.0932, -24.5670, 14.7513],\n", + " [-15.6899, -10.5828, -17.3273, -9.8969, -18.9134, 10.8317],\n", + " [-21.1024, -17.6348, -23.0049, -15.6890, -25.8769, 16.1231]],\n", + " device='cuda:0', grad_fn=) tensor(5.4043, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-14.7314, -10.7975, -14.8153, -9.2098, -15.8382, 9.7865],\n", + " [-16.6813, -14.4243, -16.8990, -9.1762, -18.8209, 10.4882],\n", + " [-23.9544, -18.8993, -23.8129, -15.6315, -26.5117, 16.6952],\n", + " [-16.1784, -12.9683, -19.9735, -10.5396, -20.8684, 10.9172],\n", + " [ -5.0827, -7.2716, -9.1614, -3.5863, -11.1548, 2.8657]],\n", + " device='cuda:0', grad_fn=) tensor(4.2490, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.5082, -11.0677, -17.9999, -7.9218, -19.9426, 9.0200],\n", + " [-16.0838, -13.3417, -18.7227, -10.9147, -20.9888, 11.8417],\n", + " [-18.6269, -14.0471, -18.8186, -12.1918, -22.2680, 12.7281],\n", + " [-10.7434, -12.5120, -13.7637, 1.5449, -14.9066, -0.4424],\n", + " [-11.3387, -10.9150, -14.1712, -8.4411, -13.7020, 7.9346]],\n", + " device='cuda:0', grad_fn=) tensor(3.8687, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-23.9070, -20.3559, -24.4889, -15.8258, -28.5338, 18.2317],\n", + " [-11.3584, -12.1919, -11.6068, 11.2247, -16.6001, -14.1959],\n", + " [-14.2405, -11.6127, -16.6036, -6.6192, -16.4254, 6.5353],\n", + " [-17.1072, -13.2204, -18.7735, -12.2513, -20.6415, 12.3282],\n", + " [ 2.6260, -7.5613, -5.9034, -4.3944, -9.4653, -3.6368]],\n", + " device='cuda:0', grad_fn=) tensor(3.8349, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-13.2962, -11.8773, -14.5696, -11.7493, -16.9390, 11.1837],\n", + " [-17.5687, -15.0570, -18.5007, -11.0198, -21.2345, 13.0276],\n", + " [ -8.2983, -9.1486, -11.4206, 4.7790, -13.1592, -1.7166],\n", + " [ -8.1644, -9.9813, -11.8416, -8.2378, -14.3267, 6.9688],\n", + " [-21.5680, -17.1657, -22.8948, -14.1432, -26.0715, 15.8256]],\n", + " device='cuda:0', grad_fn=) tensor(4.0285, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.2102, -12.9836, -14.8315, -10.8962, -18.8475, 11.6788],\n", + " [-16.6749, -14.2858, -18.9702, -14.3616, -18.9991, 13.4779],\n", + " [-19.5277, -15.9372, -22.4681, -14.7164, -22.8174, 15.5797],\n", + " [-16.2891, -14.4581, -21.1394, -14.0207, -20.1127, 12.9275],\n", + " [-13.4849, -13.2919, -16.1890, -9.3673, -19.0055, 10.3136]],\n", + " device='cuda:0', grad_fn=) tensor(4.9055, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-10.2771, -8.6008, -11.1654, -5.6105, -12.6183, 7.4107],\n", + " [-19.0082, -14.2938, -19.8347, -12.9445, -21.8790, 13.6905],\n", + " [-19.5861, -17.8861, -20.7971, -13.7361, -22.5391, 14.2205],\n", + " [ -9.9408, -6.6259, -10.1789, -4.8387, -9.7895, 5.4335],\n", + " [-24.2889, -22.7694, -25.1084, -15.3963, -28.4358, 16.8528]],\n", + " device='cuda:0', grad_fn=) tensor(4.6909, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-12.2590, -12.2529, -13.6483, -7.7208, -14.7364, 8.5842],\n", + " [-13.4446, -12.6274, -15.1673, -8.5549, -17.6365, 10.3967],\n", + " [-23.5282, -21.2204, -23.7261, -16.0989, -28.8771, 17.5752],\n", + " [-25.7697, -23.3372, -25.1170, -16.1030, -27.3975, 
16.6778],\n", + " [-24.7674, -19.2580, -24.2328, -15.1823, -27.6261, 16.1211]],\n", + " device='cuda:0', grad_fn=) tensor(5.6375, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.7339, -18.3260, -22.7324, -15.5009, -25.4800, 15.9922],\n", + " [-20.9870, -19.1441, -23.7578, -8.6340, -25.0596, 11.2608],\n", + " [-16.9437, -15.1411, -18.1658, -10.8829, -20.7621, 11.7862],\n", + " [-15.5965, -12.6690, -18.0372, -10.1243, -18.7799, 11.1244],\n", + " [-17.5048, -16.0828, -18.9316, -10.9727, -21.0505, 12.4862]],\n", + " device='cuda:0', grad_fn=) tensor(5.1472, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.1076, -14.6969, -17.6454, -10.5643, -18.4186, 11.0711],\n", + " [-23.8322, -18.0463, -22.9464, -14.8553, -27.5256, 15.5891],\n", + " [-20.9450, -19.2244, -19.7388, -12.8142, -26.1809, 13.4994],\n", + " [-19.4913, -15.1529, -19.6476, -11.8483, -22.7901, 12.7063],\n", + " [-13.4831, -12.1989, -14.2512, -10.3662, -17.7110, 9.4714]],\n", + " device='cuda:0', grad_fn=) tensor(5.2732, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.6361, -15.5828, -18.0406, -11.4602, -20.7875, 12.4264],\n", + " [-22.6797, -17.8540, -24.9452, -12.1069, -26.9300, 14.7086],\n", + " [-12.6837, -10.5632, -13.3694, -6.2251, -14.1930, 8.0056],\n", + " [-14.4769, -14.8849, -16.7415, -7.2062, -17.8976, 9.2677],\n", + " [-15.2257, -13.9795, -17.6534, -10.5351, -19.5998, 11.9774]],\n", + " device='cuda:0', grad_fn=) tensor(4.6364, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.8367, -15.8390, -21.6865, -13.6612, -22.6122, 13.8827],\n", + " [-13.1518, -12.6077, -15.2049, -9.6903, -17.9623, 9.2223],\n", + " [-16.5043, -12.6292, -15.8804, -9.7129, -17.5602, 10.6516],\n", + " [-18.5433, -14.8849, -19.3049, -11.9239, -21.6641, 12.5196],\n", + " [-21.6438, -16.3592, -20.4209, -15.6192, -23.5265, 15.3546]],\n", + " device='cuda:0', grad_fn=) tensor(5.0437, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-12.9467, -13.5462, -17.9762, -12.1280, -19.3299, 11.7681],\n", + " [-17.0266, -14.3495, -18.9394, -13.4108, -20.9126, 13.7638],\n", + " [-21.4019, -19.1468, -25.2928, -16.9116, -26.3731, 17.1162],\n", + " [-26.3484, -23.6832, -27.5170, -17.2330, -28.0154, 18.2451],\n", + " [-16.0462, -13.2988, -16.9302, -12.2843, -14.9032, 11.5992]],\n", + " device='cuda:0', grad_fn=) tensor(5.5421, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.7290, -15.8293, -21.6415, -13.0583, -24.8633, 14.1536],\n", + " [-17.1478, -12.9604, -18.5224, -9.2067, -19.9271, 11.3396],\n", + " [-20.3198, -17.8568, -19.6014, -13.6527, -19.7684, 13.9269],\n", + " [-18.8919, -17.4777, -18.8643, -12.4781, -23.4350, 13.2882],\n", + " [-14.4659, -15.1607, -18.4798, -11.4604, -19.2335, 11.1741]],\n", + " device='cuda:0', grad_fn=) tensor(5.1812, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.6202, -13.9040, -14.9371, -10.6351, -18.3301, 10.9236],\n", + " [-17.5480, -12.7234, -18.4907, -12.3305, -20.6987, 13.0662],\n", + " [-24.0887, -18.2728, -23.3915, -13.7346, -26.1900, 15.9506],\n", + " [-20.6726, -19.0884, -21.0361, -13.8959, -24.8723, 15.4405],\n", + " [-14.2718, -14.0235, -14.0753, -10.1550, -17.4075, 10.3238]],\n", + " device='cuda:0', grad_fn=) tensor(5.3302, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.6710, -17.2879, -18.2629, -11.3739, -22.1345, 14.0307],\n", + " [-22.0841, -18.4903, -22.5544, -13.2804, -24.1217, 14.5825],\n", + " [-12.6307, -11.1834, -13.0976, -6.1837, -13.2162, 6.7875],\n", + " [-15.5691, -15.8008, -13.1474, -10.4582, -17.9349, 10.1040],\n", + " [-15.1532, -12.8108, -15.1631, -9.5148, 
-18.1769, 10.6894]],\n", + " device='cuda:0', grad_fn=) tensor(4.7435, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.1744, -11.6438, -15.7907, -9.8355, -17.2297, 12.3719],\n", + " [-11.8090, -10.7504, -10.7309, -9.9834, -12.2582, 8.2313],\n", + " [-14.7942, -11.5056, -14.0296, -10.6525, -16.6547, 11.2133],\n", + " [ -6.0504, -3.3838, -2.6464, -2.6269, -6.5703, 1.2754],\n", + " [-17.6665, -14.7388, -19.0898, -13.3054, -19.9522, 13.9037]],\n", + " device='cuda:0', grad_fn=) tensor(3.7637, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[ -9.6479, -7.9200, -11.8770, -7.9571, -11.9346, 7.5819],\n", + " [-14.0423, -13.3970, -11.3144, -8.6978, -14.2912, 7.6402],\n", + " [-19.8466, -17.1858, -17.6502, -12.6322, -20.5368, 13.1196],\n", + " [-26.9309, -20.4827, -27.0063, -15.4244, -28.0346, 17.7513],\n", + " [-14.3584, -13.3932, -14.8911, -9.2751, -18.0065, 9.8739]],\n", + " device='cuda:0', grad_fn=) tensor(4.6932, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.3330, -9.8182, -18.5348, -11.8453, -18.9183, 10.8044],\n", + " [-15.1557, -13.9038, -15.7074, -10.8582, -18.2332, 11.8055],\n", + " [-18.6996, -14.5197, -20.9992, -12.6252, -20.1872, 13.1674],\n", + " [-22.8469, -19.7574, -24.8654, -15.2763, -26.8487, 17.1085],\n", + " [-10.4986, -8.8861, -10.7186, -5.3894, -13.6085, 6.3734]],\n", + " device='cuda:0', grad_fn=) tensor(4.7266, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.6296, -15.6283, -20.6069, -12.6860, -23.3083, 14.4273],\n", + " [-13.3279, -9.0744, -13.8022, -8.2591, -16.8325, 8.2038],\n", + " [-13.4357, -9.8973, -15.1503, -10.1377, -16.4957, 10.2683],\n", + " [-10.2582, -8.3794, -14.0234, -7.2057, -14.7144, 6.7166],\n", + " [-13.6460, -8.2001, -14.9831, -9.2263, -17.0637, 9.4848]],\n", + " device='cuda:0', grad_fn=) tensor(3.9467, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-11.6963, -9.2469, -12.6617, -8.1706, -14.4895, 9.3311],\n", + " [-20.3952, -18.3361, -21.6297, -13.9980, -25.2609, 15.2594],\n", + " [-22.0856, -18.9060, -20.6000, -14.3744, -24.6768, 15.1909],\n", + " [-15.2122, -10.5237, -15.9923, -11.6820, -18.4171, 10.4842],\n", + " [-13.1434, -11.2104, -9.8275, -10.0713, -15.1064, 8.3303]],\n", + " device='cuda:0', grad_fn=) tensor(4.7043, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-13.2211, -9.5423, -13.1478, -8.8343, -15.3084, 8.7082],\n", + " [-16.7319, -12.6597, -17.8649, -11.3318, -20.8555, 11.3941],\n", + " [-17.7222, -10.5871, -19.6106, -12.1121, -19.2650, 12.1031],\n", + " [-19.6047, -18.7017, -20.6850, -10.4466, -24.2143, 12.4366],\n", + " [-13.8230, -12.3804, -15.4939, 6.6683, -17.3757, -2.6268]],\n", + " device='cuda:0', grad_fn=) tensor(4.4162, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.1193, -14.8870, -17.5374, -11.4638, -20.5449, 13.0229],\n", + " [-20.6139, -17.4176, -19.7047, -10.7328, -22.7407, 13.2464],\n", + " [-20.6287, -18.4637, -22.4636, -15.8730, -25.1109, 17.0365],\n", + " [-15.1208, -13.3592, -16.8950, -9.7724, -18.1916, 10.8109],\n", + " [-21.5424, -18.0085, -20.5713, -12.5505, -23.4866, 15.6004]],\n", + " device='cuda:0', grad_fn=) tensor(5.5247, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.4129, -18.2680, -19.5686, -12.8061, -24.2072, 14.3643],\n", + " [-24.4469, -17.0240, -26.6052, -14.7471, -27.1959, 15.1867],\n", + " [-20.6170, -19.4698, -20.2167, -12.6614, -23.4191, 14.6381],\n", + " [-18.8323, -16.9009, -22.7492, -14.6131, -22.3041, 15.3515],\n", + " [-19.2088, -17.0492, -21.2708, -12.2217, -23.2812, 14.3415]],\n", + " device='cuda:0', grad_fn=) tensor(5.9133, 
device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.8148, -17.4347, -20.4716, -12.1514, -23.6210, 15.4582],\n", + " [-19.1338, -15.6345, -20.4077, -11.0689, -22.6674, 13.5178],\n", + " [-18.6162, -16.1241, -22.1729, -13.1607, -23.0949, 13.3637],\n", + " [-15.0940, -12.3485, -15.2322, -9.0852, -16.6256, 8.9301],\n", + " [-20.9731, -17.2684, -19.6739, -10.5838, -21.9027, 12.4065]],\n", + " device='cuda:0', grad_fn=) tensor(5.2436, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.5214, -15.9404, -20.7522, -11.6856, -22.1833, 12.5102],\n", + " [-14.5146, -12.0837, -12.8237, -12.1591, -16.6726, 10.1114],\n", + " [-16.6550, -16.0756, -16.9363, -12.7888, -20.6022, 13.5074],\n", + " [-13.6076, -9.4316, -15.9310, -9.9720, -16.3886, 11.5593],\n", + " [-20.2257, -16.8778, -22.4005, -13.2786, -21.4215, 14.5956]],\n", + " device='cuda:0', grad_fn=) tensor(4.8603, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.1366, -17.3678, -22.1653, -13.3608, -22.3511, 15.2668],\n", + " [-11.4665, -11.4157, -12.7001, -7.7738, -12.1872, 7.1619],\n", + " [-18.3742, -13.8976, -18.9590, -9.7601, -20.8011, 10.7781],\n", + " [-13.9690, -6.9894, -13.6376, -7.2607, -13.0903, 5.0189],\n", + " [-27.2542, -21.2988, -29.1481, -18.2303, -31.1168, 19.5048]],\n", + " device='cuda:0', grad_fn=) tensor(4.9647, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.7593, -19.5813, -23.7843, -15.7244, -24.8389, 17.3628],\n", + " [-24.4455, -20.9051, -25.2435, -15.8720, -28.3832, 16.1786],\n", + " [-19.1644, -14.5254, -20.0183, -13.1755, -24.6596, 14.1023],\n", + " [-16.2799, -10.8130, -17.4141, -8.0549, -18.6652, 8.5796],\n", + " [-20.8250, -17.1108, -22.6241, -13.8311, -26.5337, 15.1030]],\n", + " device='cuda:0', grad_fn=) tensor(5.8267, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.8580, -19.5403, -21.5024, -11.0178, -24.2921, 13.4842],\n", + " [-16.8547, -13.1214, -18.1288, -10.4439, -20.0097, 11.6236],\n", + " [-18.8804, -16.6998, -20.2242, -12.0481, -23.1131, 13.1611],\n", + " [-13.9925, -12.4996, -15.5666, -11.0671, -17.2274, 10.1886],\n", + " [-13.9824, -12.9184, -15.4214, -7.4420, -16.6134, 8.4947]],\n", + " device='cuda:0', grad_fn=) tensor(4.6840, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.2970, -14.2517, -17.7940, -11.0901, -20.4699, 12.0047],\n", + " [-22.3773, -17.9124, -24.3739, -13.9360, -27.1090, 16.0533],\n", + " [-16.8401, -15.5732, -19.3255, -12.9068, -22.7579, 12.8500],\n", + " [-16.9204, -15.4857, -19.5650, -11.0775, -21.3156, 12.2616],\n", + " [-18.7898, -11.4660, -22.6002, -13.2158, -23.7073, 12.9592]],\n", + " device='cuda:0', grad_fn=) tensor(5.2451, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.3884, -13.5856, -17.9079, -9.7859, -17.7833, 11.7924],\n", + " [-19.8772, -14.3560, -22.5769, -11.7164, -22.1028, 12.6778],\n", + " [-16.5265, -15.7999, -19.0882, -11.4491, -20.8517, 13.6349],\n", + " [-16.6222, -15.9760, -16.6970, -11.4053, -20.2419, 12.7633],\n", + " [-19.3569, -12.0447, -22.2062, -14.7031, -22.1541, 13.1451]],\n", + " device='cuda:0', grad_fn=) tensor(5.0928, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-32.9812, -27.4990, -33.0810, -19.9676, -37.1122, 22.7010],\n", + " [-25.1412, -22.5083, -26.1246, -17.9399, -30.2328, 19.0551],\n", + " [-18.1044, -17.4468, -16.1489, -12.2034, -21.7271, 13.3791],\n", + " [-22.0484, -18.7400, -23.6264, -15.1139, -25.0772, 16.2556],\n", + " [-21.4765, -19.0986, -21.7936, -14.6659, -25.0793, 16.4497]],\n", + " device='cuda:0', grad_fn=) tensor(6.9197, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.9703, 
-12.5283, -15.9123, -10.3607, -16.9043, 11.8670],\n", + " [-14.7851, -11.3858, -14.6984, -8.0111, -17.8239, 9.1729],\n", + " [-23.3414, -16.4799, -23.8550, -14.2612, -25.1833, 15.5180],\n", + " [-11.8542, -14.9106, -17.7117, -10.1824, -19.5678, 9.4499],\n", + " [-16.5264, -11.4633, -14.9103, -13.0815, -16.8892, 11.1628]],\n", + " device='cuda:0', grad_fn=) tensor(4.6550, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-24.5753, -19.0693, -25.6013, -16.3616, -29.0606, 16.9855],\n", + " [-18.4498, -16.5574, -19.0798, -12.9496, -22.0195, 15.4385],\n", + " [-17.8592, -15.2445, -20.2465, -10.8744, -21.1017, 12.4335],\n", + " [-16.3447, -16.2394, -20.7919, -11.0883, -22.6917, 11.9366],\n", + " [-20.8857, -16.6305, -22.3670, -13.4147, -25.3281, 15.1691]],\n", + " device='cuda:0', grad_fn=) tensor(5.6693, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.3107, -17.2678, -24.4074, -15.4727, -27.0266, 16.8336],\n", + " [-18.4084, -14.8985, -19.2755, -12.0711, -21.3660, 14.0156],\n", + " [-18.7173, -18.2618, -20.5120, -13.1939, -24.9243, 13.7050],\n", + " [-18.5922, -13.9030, -18.8118, -10.7443, -20.9841, 11.6499],\n", + " [-18.7843, -15.1183, -20.3662, -9.4882, -22.0089, 11.7991]],\n", + " device='cuda:0', grad_fn=) tensor(5.4939, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-14.5255, -12.5962, -17.4335, -11.1467, -18.9626, 11.4310],\n", + " [-18.3658, -15.5181, -19.8912, -12.0178, -22.4920, 14.0607],\n", + " [-17.5977, -14.6222, -19.6648, -11.2494, -21.4487, 12.7164],\n", + " [-10.3094, -10.8707, -10.0709, -7.6031, -13.0766, 7.3406],\n", + " [-18.6899, -15.6591, -20.4259, -11.2657, -22.5217, 13.6514]],\n", + " device='cuda:0', grad_fn=) tensor(4.6230, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.0740, -17.1642, -19.3914, -11.9442, -23.4112, 13.1909],\n", + " [-17.8197, -14.5614, -19.1308, -11.5887, -20.4041, 12.1787],\n", + " [-19.0360, -13.7985, -18.8315, -12.2055, -22.0558, 12.9905],\n", + " [-16.9382, -12.0843, -19.4143, -9.3930, -20.9384, 10.6815],\n", + " [-17.4838, -15.3408, -20.2609, -9.5611, -21.3227, 11.1091]],\n", + " device='cuda:0', grad_fn=) tensor(4.9834, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-10.6968, -11.7862, -13.2379, -6.2418, -14.6098, 7.1439],\n", + " [-16.5674, -15.7555, -18.6584, -8.0537, -20.0214, 11.4463],\n", + " [-12.7221, -9.2729, -17.5542, -10.4586, -17.9921, 6.7436],\n", + " [-21.6593, -19.3343, -23.6112, -12.6774, -27.3476, 14.4801],\n", + " [-21.3644, -17.6912, -22.5517, -14.6001, -26.5790, 15.4281]],\n", + " device='cuda:0', grad_fn=) tensor(4.6085, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.8990, -17.7780, -21.0489, -13.2314, -25.6239, 14.9066],\n", + " [-19.5994, -18.4150, -22.4583, -11.8884, -24.9489, 14.8532],\n", + " [-19.2462, -17.4901, -21.9835, -10.9877, -23.2811, 13.5844],\n", + " [-15.5826, -15.2329, -19.3137, -9.8560, -20.8139, 10.5707],\n", + " [-14.5252, -10.7643, -17.0969, -8.3940, -17.8676, 9.7440]],\n", + " device='cuda:0', grad_fn=) tensor(5.1171, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-23.9719, -20.1234, -28.4602, -17.6994, -28.3996, 17.4066],\n", + " [-18.7809, -14.9155, -19.8409, -11.6401, -21.6021, 13.9610],\n", + " [-18.5305, -14.9233, -19.0456, -13.4152, -20.3604, 12.9150],\n", + " [-20.7417, -18.2628, -21.6057, -15.5778, -23.2930, 16.6837],\n", + " [-17.9475, -14.1712, -19.6554, -13.3653, -21.2943, 13.8216]],\n", + " device='cuda:0', grad_fn=) tensor(5.8254, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.7057, -18.9001, -19.3759, -13.4884, -24.9231, 14.8078],\n", + " 
[-16.1870, -12.3184, -15.6544, -11.0296, -18.6483, 10.9616],\n", + " [-15.8026, -11.9614, -16.7288, -12.8908, -19.2097, 11.4760],\n", + " [ -9.8667, -12.8459, -12.7908, -8.4372, -15.7941, 8.5906],\n", + " [-18.6315, -15.2141, -21.0891, -13.1284, -23.1672, 13.2186]],\n", + " device='cuda:0', grad_fn=) tensor(4.6416, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-12.3170, -12.1141, -13.8850, -7.9731, -17.3659, 8.5570],\n", + " [-11.4212, -10.2209, -14.5908, -7.4196, -16.2881, 7.7112],\n", + " [ -9.7739, -12.8511, -12.4381, -7.6655, -14.4346, 7.7773],\n", + " [-17.9632, -14.6031, -20.9872, -13.9909, -22.6986, 13.1252],\n", + " [-20.8732, -18.3798, -24.9761, -16.3162, -26.3668, 16.9583]],\n", + " device='cuda:0', grad_fn=) tensor(4.2160, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.7734, -18.7654, -22.4377, -14.1063, -25.6670, 16.9886],\n", + " [-14.9720, -14.9534, -19.4248, -14.3606, -19.9433, 13.3392],\n", + " [-21.3588, -18.4297, -23.3515, -14.0178, -25.4403, 16.4733],\n", + " [-16.7734, -17.7400, -21.6076, -14.7985, -24.4414, 13.4196],\n", + " [-15.9589, -13.4808, -16.9370, -10.4864, -19.4568, 11.5403]],\n", + " device='cuda:0', grad_fn=) tensor(5.4532, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.3548, -16.7785, -18.4801, -12.0774, -22.7151, 13.5736],\n", + " [-18.7703, -19.9252, -19.5998, -12.3849, -25.6287, 13.0342],\n", + " [-15.2746, -14.4263, -16.7372, -11.9905, -18.9224, 12.9142],\n", + " [-17.7992, -15.0942, -18.7344, -11.5380, -22.3793, 13.2218],\n", + " [-18.4698, -15.9423, -20.4555, -12.8131, -22.0360, 14.7222]],\n", + " device='cuda:0', grad_fn=) tensor(5.2045, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.9748, -13.7058, -19.0213, -11.8479, -21.9417, 13.5272],\n", + " [-14.9068, -15.0676, -17.9662, -12.5415, -20.2507, 11.7762],\n", + " [-20.3708, -17.9092, -20.5850, -14.4878, -24.9056, 13.6956],\n", + " [-22.7280, -18.8601, -22.8030, -14.1650, -25.4472, 16.1378],\n", + " [-17.7265, -16.3266, -17.4803, -13.3785, -21.0708, 13.4891]],\n", + " device='cuda:0', grad_fn=) tensor(5.4111, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.8754, -15.7071, -17.7449, -13.0361, -19.9818, 13.8916],\n", + " [-21.3584, -14.6113, -25.3200, -15.9540, -26.0373, 15.1054],\n", + " [-15.2743, -13.0577, -15.4156, -10.3202, -17.6266, 11.2770],\n", + " [-14.2055, -14.5312, -18.7720, -13.2209, -20.5926, 11.5475],\n", + " [-22.4946, -20.2387, -24.5248, -14.6311, -26.9361, 17.1001]],\n", + " device='cuda:0', grad_fn=) tensor(5.3043, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.5130, -13.8763, -15.8584, -10.6193, -18.6628, 11.4881],\n", + " [-20.3973, -16.6515, -20.9321, -12.2857, -22.1069, 13.7251],\n", + " [-20.5365, -19.3787, -22.4803, -10.7939, -24.9822, 14.0701],\n", + " [-20.2452, -16.8394, -19.4366, -13.7091, -23.2141, 14.7916],\n", + " [-19.1439, -17.5626, -20.2701, -14.3621, -24.5424, 15.2705]],\n", + " device='cuda:0', grad_fn=) tensor(5.5060, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.6656, -15.7138, -18.2888, -10.9559, -22.1569, 13.3772],\n", + " [-17.7336, -15.7905, -16.7400, -10.3683, -21.4636, 12.8227],\n", + " [-19.7585, -18.6270, -20.1307, -12.2257, -24.3634, 15.1827],\n", + " [-19.4242, -17.5527, -20.4448, -14.9366, -23.4163, 15.6954],\n", + " [-20.2690, -17.9048, -19.3549, -14.4854, -22.0681, 14.6792]],\n", + " device='cuda:0', grad_fn=) tensor(5.5869, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-28.1177, -20.9715, -27.2392, -17.2305, -31.3013, 17.7767],\n", + " [-19.1474, -17.2928, -19.7323, -11.9047, -22.8305, 
12.9609],\n", + " [-17.2956, -15.3368, -17.8013, -11.9221, -20.5031, 14.0885],\n", + " [-12.6312, -12.2924, -13.4635, -10.5112, -16.6389, 10.2560],\n", + " [-19.0983, -17.7811, -21.8155, -11.8038, -22.7976, 12.1057]],\n", + " device='cuda:0', grad_fn=) tensor(5.4493, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.1600, -13.0184, -16.4458, -11.0808, -18.5597, 11.1225],\n", + " [-19.3812, -16.1166, -18.3470, -13.0626, -21.5304, 14.6689],\n", + " [-21.5873, -17.4331, -20.2584, -13.1244, -24.7632, 14.0848],\n", + " [-23.7779, -18.7419, -24.6766, -15.7794, -27.1349, 15.7003],\n", + " [-23.2882, -18.8450, -23.5957, -13.4268, -25.4233, 15.2923]],\n", + " device='cuda:0', grad_fn=) tensor(5.8354, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.9866, -15.4878, -20.1301, -11.2216, -21.4158, 11.9829],\n", + " [-15.9552, -14.6556, -16.4333, -10.1787, -18.7979, 10.3496],\n", + " [-22.7104, -20.2965, -23.0794, -13.0905, -25.6637, 15.5291],\n", + " [-23.5403, -20.4541, -25.9342, -13.8349, -28.3207, 16.9053],\n", + " [-20.9428, -17.7623, -23.1671, -12.4682, -24.1768, 14.2896]],\n", + " device='cuda:0', grad_fn=) tensor(5.6397, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-19.0949, -16.7221, -20.5008, -12.0557, -23.2993, 14.4390],\n", + " [-21.4803, -18.2019, -22.1185, -15.1478, -26.1642, 15.3335],\n", + " [-21.4243, -20.2971, -23.3980, -16.1245, -24.5950, 17.2290],\n", + " [-23.4091, -22.6142, -25.9959, -16.4560, -27.4431, 18.6262],\n", + " [-22.3768, -21.1278, -23.1824, -14.7416, -27.8023, 16.7324]],\n", + " device='cuda:0', grad_fn=) tensor(6.3382, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.3863, -13.5600, -19.9461, -11.2188, -20.9589, 11.4459],\n", + " [-14.6725, -11.3649, -16.0441, -10.0943, -19.2804, 10.0089],\n", + " [-16.7444, -13.4433, -17.6785, -8.1769, -20.1085, 10.2394],\n", + " [ -9.7199, -5.7652, -10.0145, -6.7054, -9.8233, 5.9927],\n", + " [-19.9933, -16.5628, -20.9082, -9.7896, -22.6419, 11.9820]],\n", + " device='cuda:0', grad_fn=) tensor(4.2731, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.0861, -11.9654, -17.5644, -10.2035, -18.9767, 11.8931],\n", + " [-15.9785, -13.4638, -17.9364, -9.6540, -20.1270, 10.0593],\n", + " [-20.1442, -17.3396, -21.4897, -13.7703, -23.7502, 13.8689],\n", + " [-11.9642, -12.9455, -14.1612, -8.1744, -16.9217, 9.6575],\n", + " [-21.4610, -17.9839, -23.1693, -12.4272, -24.6516, 14.0490]],\n", + " device='cuda:0', grad_fn=) tensor(4.8721, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.1201, -20.7091, -23.0904, -13.9579, -26.0676, 15.6776],\n", + " [-14.5174, -14.5780, -16.3154, -10.3883, -17.2480, 10.0580],\n", + " [-14.4431, -11.3692, -16.6122, -13.5323, -19.6046, 11.2442],\n", + " [-27.7304, -24.5485, -28.6611, -19.3034, -31.4497, 20.7093],\n", + " [-22.7288, -18.1551, -24.3351, -16.3573, -26.0221, 17.1175]],\n", + " device='cuda:0', grad_fn=) tensor(5.8115, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.4127, -15.5328, -17.9819, -8.9730, -18.7125, 10.6280],\n", + " [-23.7108, -19.6902, -23.8317, -16.1002, -26.8017, 18.3754],\n", + " [-16.6169, -14.0449, -18.5231, -10.0803, -20.8324, 12.0207],\n", + " [-16.1439, -14.9718, -17.6866, -8.8520, -19.4620, 10.4882],\n", + " [-24.1040, -20.5658, -26.6451, -16.7603, -26.7051, 17.2574]],\n", + " device='cuda:0', grad_fn=) tensor(5.5253, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-18.2382, -14.3147, -18.9593, -12.2502, -19.1783, 12.1915],\n", + " [-14.6757, -12.5675, -16.1349, -7.8402, -16.6600, 8.6357],\n", + " [-19.6062, -15.8615, -20.6970, -9.9591, 
-24.0498, 10.3431],\n", + " [-14.4842, -13.9449, -16.3620, -4.8229, -17.7464, 6.8868],\n", + " [-16.3014, -13.0761, -17.5243, -10.0741, -17.6989, 11.8242]],\n", + " device='cuda:0', grad_fn=) tensor(4.4399, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-17.6400, -15.7390, -19.3633, -11.4395, -20.4425, 12.6652],\n", + " [-11.4393, -12.9374, -14.6822, -7.5145, -14.8698, 9.0114],\n", + " [-16.9832, -15.1343, -19.4977, -12.0574, -21.4640, 13.4612],\n", + " [-18.1148, -13.0337, -20.8080, -11.2504, -21.6554, 12.2917],\n", + " [-15.0605, -16.2581, -18.6113, -7.4220, -20.1719, 8.6069]],\n", + " device='cuda:0', grad_fn=) tensor(4.5092, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-16.1585, -12.7229, -17.1500, -9.8196, -20.7888, 10.2071],\n", + " [-16.0596, -11.8572, -16.5405, -10.5472, -19.1893, 11.8997],\n", + " [-16.8690, -15.1268, -17.2011, -10.8024, -20.6433, 12.2531],\n", + " [-18.7775, -17.6282, -19.9257, -12.5080, -21.3468, 14.5402],\n", + " [-14.9997, -14.0050, -15.8017, -11.3393, -19.5409, 10.9783]],\n", + " device='cuda:0', grad_fn=) tensor(4.7581, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-10.2756, -8.0491, -11.0042, -4.5234, -10.3225, 4.0230],\n", + " [-16.0398, -11.1164, -15.1176, -10.0344, -18.6164, 10.1628],\n", + " [-12.5132, -10.8211, -14.9827, -9.1742, -17.1821, 9.9145],\n", + " [-18.4439, -12.7465, -21.5300, -11.5545, -23.4562, 11.8798],\n", + " [-21.1030, -19.3982, -19.9244, -13.8952, -24.4015, 15.4825]],\n", + " device='cuda:0', grad_fn=) tensor(4.3289, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-22.6398, -17.9340, -21.5053, -14.1912, -25.2801, 14.3820],\n", + " [-16.2111, -10.9896, -17.2865, -10.2637, -19.2123, 10.6698],\n", + " [-18.0670, -15.7895, -19.0299, -10.2522, -21.1252, 12.0592],\n", + " [-15.1078, -12.0009, -16.4496, -9.8840, -18.4530, 11.3191],\n", + " [-12.8361, -9.4217, -13.4028, -7.9373, -17.3825, 8.7229]],\n", + " device='cuda:0', grad_fn=) tensor(4.7339, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-14.4710, -13.7387, -12.5091, -9.9019, -17.1015, 10.5661],\n", + " [-14.6864, -14.2809, -10.7986, -10.0192, -19.4236, 8.6897],\n", + " [-21.4247, -16.7861, -23.6037, -14.4491, -25.0436, 15.8453],\n", + " [-18.1230, -15.4493, -16.6629, -12.0134, -21.2081, 12.0806],\n", + " [-18.1554, -15.5179, -19.0351, -12.4974, -20.5325, 12.8499]],\n", + " device='cuda:0', grad_fn=) tensor(4.8964, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-13.7177, -13.8416, -15.2486, -9.6244, -18.0467, 10.3239],\n", + " [-24.1385, -18.0729, -26.3439, -15.7565, -28.2851, 17.1401],\n", + " [-17.1934, -14.4554, -17.3910, -11.3049, -20.5639, 11.8824],\n", + " [-17.3398, -14.8960, -17.9027, -8.2123, -20.2282, 9.2361],\n", + " [-19.1341, -14.8859, -21.3039, -11.6237, -20.6105, 13.4214]],\n", + " device='cuda:0', grad_fn=) tensor(5.1176, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-20.9768, -18.6529, -21.9484, -14.2502, -24.1503, 16.5473],\n", + " [-23.8706, -19.8372, -24.6863, -16.7051, -25.5844, 17.6105],\n", + " [-17.2622, -14.3734, -17.6110, -9.3845, -15.7502, 8.8129],\n", + " [-13.8973, -12.4226, -15.6001, -11.5369, -15.7504, 11.3192],\n", + " [-14.9060, -14.4255, -19.7606, -13.1144, -20.6059, 12.2337]],\n", + " device='cuda:0', grad_fn=) tensor(5.2479, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-12.6985, -11.5600, -11.3345, -7.3910, -14.6318, 7.7115],\n", + " [-13.4991, -11.5705, -14.6853, -6.5622, -14.5839, 10.1141],\n", + " [ -7.4961, -6.3502, -6.6151, -3.9471, -5.8693, 3.8525],\n", + " [-17.8359, -15.9585, -16.5560, -10.1990, -18.1437, 
10.4733],\n", + " [ -8.6907, -9.6264, -9.6112, -1.7396, -10.4617, 3.7027]],\n", + " device='cuda:0', grad_fn=) tensor(3.2103, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-10.6845, -9.8545, -10.1369, -4.1881, -11.8488, 7.3928],\n", + " [-16.7131, -14.5159, -15.4915, -10.1817, -18.0355, 10.3292],\n", + " [-13.8805, -10.1855, -11.8923, -8.8698, -15.4124, 8.3858],\n", + " [-12.4117, -9.6254, -12.4239, -6.2932, -13.4157, 7.7600],\n", + " [-19.5006, -17.0036, -18.6615, -13.4596, -21.4147, 14.0077]],\n", + " device='cuda:0', grad_fn=) tensor(4.0361, device='cuda:0',\n", + " grad_fn=)\n", + "tensor([[-15.3934, -12.1062, -16.5788, -7.5249, -17.8767, 8.9113],\n", + " [-17.2418, -14.8963, -21.0337, -12.5611, -22.3435, 12.9004],\n", + " [-15.2857, -13.2404, -17.3960, -9.4093, -18.8769, 10.5207],\n", + " [-20.7384, -17.1768, -24.1848, -13.4080, -25.2121, 14.0100],\n", + " [-12.2906, -11.0539, -12.2814, -9.7475, -13.5809, 9.4918]],\n", + " device='cuda:0', grad_fn=) tensor(4.5595, device='cuda:0',\n", + " grad_fn=)\n" + ] + } + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "results = []\n", + "model.cuda()\n", + "for i in range(0, ds[\"test\"].shape[0], 5):\n", + " x = TimmInputs(labels=torch.Tensor(ds[\"test\"][i:i+5].labels).cuda(),\n", + " spectrogram=torch.Tensor(np.concat(ds[\"test\"][i:i+5].spectrogram)).unsqueeze(1).cuda())\n", + " out = model(x)\n", + " results.append(np.array(torch.argmax(out[\"logits\"], dim=1).cpu()))\n", + " if i > 500:\n", + " break\n" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "1268147f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dec69cd5131646a6bcfa0772c7ee371d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/31537 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ 
+ "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for i in np.arange(510)[np.concat(results) != 5]:\n", + " print(np.concat(results)[i])\n", + " import IPython\n", + " display(IPython.display.Audio(og_ds[\"test\"][int(i)][\"audio\"][\"path\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "51e71f6b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([ 66, 88, 94, 97, 114, 115, 116, 152, 188, 191, 194, 197, 284])]" + ] + }, + "execution_count": 106, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[np.arange(510)[np.concat(results) != 5]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed62fc90", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "whoot", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/whoot_model_training/inferance.py b/whoot_model_training/inferance.py new file mode 100644 index 0000000..c37f18e --- /dev/null +++ b/whoot_model_training/inferance.py @@ -0,0 +1,122 @@ +"""Trains a Mutliclass Model with Pytorch and Huggingface. + +This script can be used to run experiments with different +models and datasets to create any model for bioacoustic classification + +It is intended this script to be heavily modified with each experiment +(say one wants to use a different dataset, one should copy this and change the +extractor!) 
+ +Usage: + $ python train.py /path/to/config.yml + +config.yml should contain frequently changed hyperparameters +""" +import os +import argparse +import yaml + +from whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments +from whoot_model_training.data_extractor import buowset_extractor, raw_audio_extractor +from whoot_model_training.models import TimmModel, TimmInputs, TimmModelConfig +from whoot_model_training import CometMLLoggerSupplement +from train import parse_config, init_env + +from whoot_model_training.preprocessors import ( + MelModelInputPreprocessor +) + +import pickle + +def test(config, model_name=""): + """Highest level logic for inferance. + + Does the following: + - Formats the dataset into an AudioDataset + - Prepares preprocessing for each audio clip + - Builds the model + - Configures and runs the trainer + - Runs evaluation + + Args: + config (dict): the config used for training. Defined in yaml file + TODO + """ + # Extract a new dataset + ds = raw_audio_extractor( + audio_parent_folder="/mnt/restorage/Audiomoth/Raw sound files/2024/RGCB/", + output_folder="data/manual_buowset", + chunk_duration=3 + ) + + # ds = buowset_extractor( + # metadata_csv=config["metadata_csv"], + # parent_path=config["data_path"], + # output_path=config["hf_cache_path"], + # ) + + # Create the model + model = TimmModel.from_pretrained(model_name) + + preprocessor = MelModelInputPreprocessor( + TimmInputs, duration=3 + ) + + ds["train"].set_transform(preprocessor) + # ds["valid"].set_transform(preprocessor) + # ds["test"].set_transform(preprocessor) + + + model_name = "efficientnet_b1" + run_name = f"buowset1.1_{model_name}_ATTEMPT_TO_STUDY_NEW_DATA" + + # trainer = WhootTrainer._load_from_checkpoint(model_name) + + # Run training + training_args = WhootTrainingArguments( + run_name=run_name, + subproject_name=config["SUBPROJECT_NAME"]+"_INFERANCE", + dataset_name=config["DATASET_NAME"], + ) + + # COMMON OPTIONAL ARGS + training_args.num_train_epochs = 5 + training_args.eval_steps = 100 + training_args.per_device_train_batch_size = 16 + training_args.per_device_eval_batch_size = 16 + training_args.dataloader_num_workers = 1 + training_args.run_name = run_name + + trainer = WhootTrainer( + model=model, + dataset=ds, + training_args=training_args, + logger=CometMLLoggerSupplement( + augmentations=None, + name=training_args.run_name + ), + ) + + # print(ds["train"].shape, ds["test"].shape, ds["valid"].shape) + # input() + + out = trainer.predict(ds["train"], metric_key_prefix="train") + print(out) + with open(run_name + ".pkl", mode="wb") as f: + pickle.dump(out, f) + # trainer.evaluate(ds["test"], metric_key_prefix="test") + # trainer.evaluate(ds["valid"], metric_key_prefix="valid") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Input config path") + parser.add_argument("config", type=str, help="Path to config.yml") + parser.add_argument( + "--model_name", + required=False, + help="path to weights or hugging face repo id", + default="/home/sean/whoot/checkpoint-4985") + args = parser.parse_args() + _config = parse_config(args.config) + + init_env(_config) + test(_config, model_name=args.model_name) diff --git a/whoot_model_training/train.py b/whoot_model_training/train.py index 316f860..a8658cf 100644 --- a/whoot_model_training/train.py +++ b/whoot_model_training/train.py @@ -24,6 +24,7 @@ from whoot_model_training.preprocessors import ( MelModelInputPreprocessor ) +from whoot_model_training.preprocessors.spectrogram_preprocessors import 
SpectrogramParams # Uncomment for use with data augmentation # from pyha_analyzer.preprocessors import MixItUp, ComposeAudioLabel @@ -69,9 +70,11 @@ def train(config): ) # Create the model - run_name = "flac_pylint_test_efficientnet_b1_buowset" + model_name = "efficientnet_b1" + + run_name = f"buowset1.1_{model_name}" model_config = TimmModelConfig( - timm_model="efficientnet_b1", + timm_model=model_name, num_classes=ds.get_num_classes()) model = TimmModel(model_config) @@ -101,13 +104,30 @@ def train(config): # ) # ]) + spectrogram_params = SpectrogramParams() + # spectrogram_params = SpectrogramParams( + # n_mels = 224, + # hop_length = 286, + # ) + # """Dataclass for spectrogram Parameters. + + # n_fft: (int) number of fft bins + # hop_length (int) skip count + # power: (float) usually 2 + # n_mels: (int) number of mel bins + # """ + # n_fft: int = 2048 + # hop_length: int = 256 + # power: float = 2.0 + # n_mels: int = 256 + # Online preprocessors prepare data for training train_preprocessor = MelModelInputPreprocessor( - TimmInputs, duration=3 + TimmInputs, duration=3, spectrogram_params=spectrogram_params ) preprocessor = MelModelInputPreprocessor( - TimmInputs, duration=3 + TimmInputs, duration=3, spectrogram_params=spectrogram_params ) ds["train"].set_transform(train_preprocessor) @@ -122,11 +142,11 @@ def train(config): ) # COMMON OPTIONAL ARGS - training_args.num_train_epochs = 2 + training_args.num_train_epochs = 5 training_args.eval_steps = 100 - training_args.per_device_train_batch_size = 32 - training_args.per_device_eval_batch_size = 32 - training_args.dataloader_num_workers = 36 + training_args.per_device_train_batch_size = 16 + training_args.per_device_eval_batch_size = 16 + training_args.dataloader_num_workers = 1 training_args.run_name = run_name trainer = WhootTrainer( diff --git a/whoot_model_training/whoot_model_training/data_extractor/__init__.py b/whoot_model_training/whoot_model_training/data_extractor/__init__.py index 7b5e158..9b3ef7f 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/__init__.py +++ b/whoot_model_training/whoot_model_training/data_extractor/__init__.py @@ -8,10 +8,11 @@ buowset_binary_extractor, ) from .esc50_extractor import esc50_extractor +from .raw_audio_extractor import raw_audio_extractor from .Jacuzzi_Olden_extractor import Jacuzzi_Olden_Extractor from .xc_extractor import xc_extractor -__all__ = ["buowset_extractor", "buowset_binary_extractor", "esc50_extractor", "Jacuzzi_Olden_Extractor", "xc_extractor"] +__all__ = ["buowset_extractor", "buowset_binary_extractor", "esc50_extractor", "Jacuzzi_Olden_Extractor", "xc_extractor", "raw_audio_extractor"] def concat_dataset(datasetA, datasetB): for split in datasetA.keys(): @@ -28,4 +29,4 @@ def concat_dataset(datasetA, datasetB): # should be able to merge # Metadata - # Consider dropping all non-required columns, will make merge easier \ No newline at end of file + # Consider dropping all non-required columns, will make merge easier diff --git a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py new file mode 100644 index 0000000..443e4e3 --- /dev/null +++ b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py @@ -0,0 +1,247 @@ +"""Processes raw audio folders. 
+ +Extractor for general, typically unlabeled soundscape recordings + +Fits as much as possible to the AudioDataset standard but +NOT INTENDED FOR TRAINING + +Rather just a placeholder to help inferance work +""" +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union +import os +import numpy as np +from datasets import ( + load_dataset, + Audio, + concatenate_datasets, + DatasetDict, + ClassLabel, + Sequence, + Dataset, + table +) +import librosa +from math import floor +from tqdm import tqdm +import pyarrow as pa + +from ..dataset import AudioDataset + + + +class SubAudio(Audio): + """Extends Audio to take a chunks of data. + + Uses code from the Hugging Face Audio Class + https://github.com/huggingface/datasets/blob/5dc1a179783dff868b0547c8486268cfaea1ea1f/src/datasets/features/audio.py#L24 + + The Audio Column of a HuggingFace dataset + handles loading in data from a given file + + What is nice is it streams data: it doesn't get loaded into + memory until it is needed via the path + + However, if we wanted to load in a chunk of data (some segment) + We would need to load it as an array instead of a path + And it gets loaded into memory. Huge issue with large audio datasets. + + By default HF doesn't support chunking, so this class should handle chunking + During streaming rather than during dataset creation + + You can use it the same way you might with the Audio class. In fact, with normal processing + it handles the same way! + + To use the chunking feature, create a Audio row with the following parameters + - path: as is with Audio + - sampling_rate: as is with Audio + - offset: NEW, offset in seconds of when to start taking audio data + - duration: NEW, duration from offset in seconds for how much data to collect + + You need both offset and duration to load in the chunk, otherwise it will load the full file. 
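# A minimal sketch of the intended usage (assumes the SubAudio class defined in this
# file; the recording path and the two five-second windows are placeholders, not repo
# data, and the path must point to a real file for decoding to succeed):
from datasets import Dataset

rows = {
    "audio": [
        {"path": "deployments/site_a/dawn.wav", "offset": 0, "duration": 5},
        {"path": "deployments/site_a/dawn.wav", "offset": 5, "duration": 5},
    ],
    "labels": [5, 5],
}
ds = Dataset.from_dict(rows).cast_column("audio", SubAudio())

# Decoding is lazy: only when a row is accessed does decode_example call
# librosa.load(path, offset=..., duration=...), so just that window is read into memory.
clip = ds[0]["audio"]
print(clip["sampling_rate"], clip["array"].shape)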
+ """ + + pa_type: ClassVar[Any] = pa.struct({ + "bytes": pa.binary(), + "path": pa.string(), + "offset": pa.int64(), + "duration": pa.int64() + }) + + def __call__(self): + return self.pa_type + + def encode_example(self, value) -> dict: + if ( + isinstance(value, dict) + and value.get("offset") + and value.get("duration") + and value.get("path") is not None + and os.path.isfile(value["path"]) + ): + y, sr = librosa.load(path = value["path"], offset=value["offset"], duration=value["duration"]) + value["array"] = y + value["sampling_rate"] = sr + encoded = super().encode_example(value) + encoded["offset"] = value["offset"] + encoded["duration"] = value["duration"] + encoded["path"] = encoded["path"] + return encoded + return super().encode_example(value) + + def decode_example(self, value, token_per_repo_id=None) -> dict: + # print("d4ecode", value) + if ( + isinstance(value, dict) + and "offset" in value + and "duration" in value + and value.get("bytes") is None + and value.get("path") is not None + and os.path.isfile(value["path"]) + ): + y, sr = librosa.load(path = value["path"], offset=value["offset"], duration=value["duration"]) + return { + "path": value["path"], + "array": y, + "sampling_rate": sr, + "offset": value["offset"], + "duration": value["duration"]} + elif ( + isinstance(value, dict) + and value.get("offset") + and value.get("duration") + and value.get("bytes") is not None + ): + decoded = super().decode_example(value, token_per_repo_id=token_per_repo_id) + decoded["offset"] = value["offset"] + decoded["duration"] = value["duration"] + return decoded + return super().decode_example(value, token_per_repo_id=token_per_repo_id) + + def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.StructArray: + # print("cast_storage real") + if pa.types.is_struct(storage.type): + if storage.type.get_field_index("bytes") >= 0: + bytes_array = storage.field("bytes") + else: + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + if storage.type.get_field_index("path") >= 0: + path_array = storage.field("path") + else: + path_array = pa.array([None] * len(storage), type=pa.string()) + if storage.type.get_field_index("offset") >= 0: + offset_array = storage.field("offset") + else: + offset_array = pa.array([None] * len(storage), type=pa.int64()) + if storage.type.get_field_index("duration") >= 0: + duration_array = storage.field("duration") + else: + duration_array = pa.array([None] * len(storage), type=pa.int64()) + storage = pa.StructArray.from_arrays([bytes_array, path_array, offset_array, duration_array], ["bytes", "path", "offset", "duration"], mask=storage.is_null()) + return table.array_cast(storage, self.pa_type) + +from datasets.features.features import _FEATURE_TYPES, FeatureType +_FEATURE_TYPES[SubAudio.__name__] = SubAudio +FeatureType = Union[FeatureType, SubAudio] + +def get_empty_dict(): + return { + "audio": [], + "file_path": [], + "labels": [], + } + +def get_array_chunks_from_memory( + parent_folder, + chunk_length_sec = 5, + no_class_idx = 5, + output_path = "/data/manual_buowset" +): + + new_rows = get_empty_dict() + _datasets = [] + for root, dirs, files in tqdm(os.walk(parent_folder), desc="All Folders"): + for filename in tqdm(files, leave=False, desc=f"file in dir"): + try: + if not filename.lower().endswith((".wav", ".mp3", ".flac", ".ogg", ".m4a")): + continue + file_path = os.path.join(root, filename) + try: + clip_length = librosa.get_duration(path=file_path) + sr = librosa.get_samplerate(path=file_path) + except BaseException 
as e: + print(file_path, "failed stat read", "continuing") + continue + for i in tqdm(range(0, int(floor(clip_length)), chunk_length_sec), leave=False, desc=f"{filename}"): + new_rows["audio"].append({ + "path": file_path, + # "sampling_rate": sr, + "offset": i, + "duration": chunk_length_sec + }) + new_rows["file_path"].append(filename) + new_rows["labels"].append(no_class_idx) + # break #TODO REMOVE + except BaseException as e: + print(e) + + # This helps make sure stuff isn't loaded into memory + # Hopefully + file_ds = Dataset.from_dict(new_rows).cast_column("audio", SubAudio()) + new_rows = get_empty_dict() + _datasets.append(file_ds) + # break #TODO REMOVE + # if len(_datasets) > 1: #TODO REMOVE + # break #TODO REMOVE + + return concatenate_datasets(_datasets) + +def one_hot_encode(row: dict, classes: list): + """One hot Encodes a list of labels. + + Args: + row (dict): row of data in a dataset containing a labels column + classes: a list of classes + """ + one_hot = np.zeros(len(classes)) + one_hot[row["labels"]] = 1 + row["labels"] = np.array(one_hot, dtype=int) + return row + +def raw_audio_extractor( + audio_parent_folder:str = "", + sr=32_000, + class_list = ["cluck", "coocoo", "twitter", "alarm", "chick begging", "no_buow"], + chunk_duration = -1, + output_folder = "" + +): + """Extracts raw, unlabeled data in the buowset format into an AudioDataset. + + Args: + audio_parent_folder (str): Path to the parent folder for all audio data. + Note its assumed the audio filepath + in the csv is relative to parent_path + sr (int): Sample Rate of the audio files Default: 32_000 + + Returns: + (AudioDataset): See dataset.py, AudioDatasets are consider + the universal dataset for the training pipeline. + """ + + dataset = get_array_chunks_from_memory( + parent_folder=audio_parent_folder, + chunk_length_sec=chunk_duration, + ) + + # # # # Convert to a uniform one_hot encoding for classes + dataset = dataset.class_encode_column("labels") + multilabel_class_label = Sequence(ClassLabel(names=class_list)) + dataset = dataset.map(lambda row: one_hot_encode(row, class_list)).cast_column( + "labels", multilabel_class_label + ) + + ds = AudioDataset( + DatasetDict({"train": dataset, "valid": dataset, "test": dataset}) + ) + return ds + diff --git a/whoot_model_training/whoot_model_training/models/__init__.py b/whoot_model_training/whoot_model_training/models/__init__.py index 3c46187..6bc07f3 100644 --- a/whoot_model_training/whoot_model_training/models/__init__.py +++ b/whoot_model_training/whoot_model_training/models/__init__.py @@ -5,6 +5,7 @@ """ from .timm_model import TimmModel, TimmInputs, TimmModelConfig +from .hf_models import HFModel, HFModelConfig, HFInput from .model import Model, ModelInput, ModelOutput from .few_shot_model import PerchEmbeddingInput, PerchFewShotModel, FewShotModelConfig @@ -12,6 +13,9 @@ "TimmModel", "TimmInputs", "TimmModelConfig", + "HFModel", + "HFModelConfig", + "HFInput" "Model", "ModelInput", "ModelOutput", diff --git a/whoot_model_training/whoot_model_training/models/hf_models.py b/whoot_model_training/whoot_model_training/models/hf_models.py new file mode 100644 index 0000000..b01bf14 --- /dev/null +++ b/whoot_model_training/whoot_model_training/models/hf_models.py @@ -0,0 +1,142 @@ +from transformers import AutoFeatureExtractor, AutoModel +import librosa + +"""Wrapper around the timms model zoo! 
+ +See https://timm.fast.ai/ + +Timm model zoo good for computer vision models +Like CNNs, which are useful for spectrograms + +Great repo for models, but currently using this for demoing pipeline +""" + +import timm +from torch import nn +from transformers import PretrainedConfig + +from .model import Model, ModelInput, ModelOutput, has_required_inputs + + +class HFInput(ModelInput): + """Input for TimmModels. + + Specifies TimmModels needs labels and spectrograms that are Tensors + """ + def __init__(self, labels, spectrogram=None, waveform=None, extractor_path="DBD-research-group/Bird-MAE-Base"): + """Creates TimmInputs. + + Args: + labels: the data's label for this batch + audio_data: some audio_data, basically this has a feature extractor for it + """ + + # print("fe works") + feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_path, trust_remote_code=True) + + mel_spectrogram = feature_extractor(waveform) + + # # Can use inputs to verify correct shape for upstream model + # assert spectrogram.shape[1:] == (1, 100, 100) + super().__init__(labels, waveform=waveform, spectrogram=mel_spectrogram) + self.labels = labels + self.spectrogram = mel_spectrogram + +class HFModelConfig(PretrainedConfig): + """Config for Timm Model Zoo Models!""" + def __init__( + self, + path="DBD-research-group/Bird-MAE-Huge", + num_classes=6, + embeddings_size=1280, + **kwargs + ): + """Creates Config. + + Args: + path (str): url to pull from hf model zoo + num_classes (int): number of classes in dataset, for cls + embeddings_size (int): size of output of model + """ + self.path = path + self.num_classes = num_classes + self.embeddings_size = embeddings_size + super().__init__(**kwargs) + + +class HFModel(Model, nn.Module): + """Model that uses a timm's model.""" + config_class = HFModelConfig + + def __init__( + self, + config: HFModelConfig + ): + """Init for TimmModel. + + kwargs: + timm_model (str): name of model backbone from timms to use, + Default: "resnet34" + pretrained (bool): use a pretrained model from timms, Default: True + in_chans (int): number of channels of audio: Default: 1 + num_classes (int): number of classes in the dataset: Default 6 + loss (any): custom loss function Default: BCEWithLogitsLoss + """ + super().__init__() + self.input_format = HFInput + self.output_format = ModelOutput + self.config = config + assert config.num_classes > 0 + + # Deep learning CNN backbone + self.backbone = AutoModel.from_pretrained(config.path, trust_remote_code=True) + + # Unsure if 1000 is default for all timm models. Need to check this + self.linear = nn.Linear(config.embeddings_size, config.num_classes) + + # different losses if you want to train for different problems + # BCEWithLogitsLoss is default as for Bioacoustics, the problem tends + # multilabel! + # the probability of class A occurring doesn't + # change the probability of Class B + # Many individuals can make calls at the same time! + self.loss = nn.BCEWithLogitsLoss() + + def set_custom_loss(self, loss_fn): + """Set a different loss function. + + For cases where we don't want BCEWithLogitsLoss + + Args: + loss_fn: Function to compute loss, ideally in pytorch + """ + self.loss = loss_fn + + @has_required_inputs() + def forward(self, x: HFInput) -> ModelOutput: + """Model forward function. + + Args: + x: (TimmInputs): The specific input format for Timm Models + + Returns + (ModelOutput): The model output (logits), + latent space representations (embeddings), loss and labels. 
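# To make the multilabel argument above concrete: BCEWithLogitsLoss applies an
# independent sigmoid per class, so two call types can both be "present" in one clip.
# A toy check with invented numbers (not repo data):
import torch
from torch import nn

logits = torch.tensor([[2.0, -1.0, 0.5]])   # raw scores for three call types in one clip
target = torch.tensor([[1.0, 0.0, 1.0]])    # classes 0 and 2 overlap in the same clip
loss = nn.BCEWithLogitsLoss()(logits, target)
print(loss.item())                          # per-class binary cross entropy, averaged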
+ """ + embed = self.backbone(x.spectrogram.to(self.device)).last_hidden_state + logits = self.linear(embed) + loss = self.loss(logits, x.labels) + + return ModelOutput( + logits=logits, + embeddings=embed, + loss=loss, + labels=x.labels + ) + + + + + + + diff --git a/whoot_model_training/whoot_model_training/models/timm_model.py b/whoot_model_training/whoot_model_training/models/timm_model.py index aac2094..dc16e03 100644 --- a/whoot_model_training/whoot_model_training/models/timm_model.py +++ b/whoot_model_training/whoot_model_training/models/timm_model.py @@ -20,7 +20,7 @@ class TimmInputs(ModelInput): Specifies TimmModels needs labels and spectrograms that are Tensors """ - def __init__(self, labels, waveform=None, spectrogram=None): + def __init__(self, labels, spectrogram=None): """Creates TimmInputs. Args: @@ -30,7 +30,7 @@ def __init__(self, labels, waveform=None, spectrogram=None): """ # # Can use inputs to verify correct shape for upstream model # assert spectrogram.shape[1:] == (1, 100, 100) - super().__init__(labels, waveform, spectrogram) + super().__init__(labels, waveform=None, spectrogram=spectrogram) self.labels = labels self.spectrogram = spectrogram diff --git a/whoot_model_training/whoot_model_training/preprocessors/__init__.py b/whoot_model_training/whoot_model_training/preprocessors/__init__.py index df579e5..65b121f 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/__init__.py +++ b/whoot_model_training/whoot_model_training/preprocessors/__init__.py @@ -8,8 +8,7 @@ """ from .base_preprocessor import ( - MelModelInputPreprocessor, - WaveformInputPreprocessor + MelModelInputPreprocessor, WaveformInputPreprocessor ) from .spectrogram_preprocessors import ( BuowMelSpectrogramPreprocessors diff --git a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py index dc63ef4..8d492fe 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py +++ b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py @@ -147,6 +147,7 @@ def __init__( """ wav_preprocessor = WaveformPreprocessors( duration=duration, + sr=32_000, augments=augments, ) super().__init__(wav_preprocessor, model_input) diff --git a/whoot_model_training/whoot_model_training/preprocessors/inferance_wrap.py b/whoot_model_training/whoot_model_training/preprocessors/inferance_wrap.py new file mode 100644 index 0000000..9b46e1e --- /dev/null +++ b/whoot_model_training/whoot_model_training/preprocessors/inferance_wrap.py @@ -0,0 +1,7 @@ +class MelModelInputPreprocessor(): + def __init__(self, preprocessor): + self.preprocessor = preprocessor + + def __call__(self, batch_input): + assert bat + pass \ No newline at end of file diff --git a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py index 555a801..089164e 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py @@ -77,7 +77,7 @@ def __call__(self, batch): new_labels = [] for item_idx in range(len(batch["audio"])): label = batch["labels"][item_idx] - y, sr = librosa.load(path=batch["audio"][item_idx]["path"]) + y, sr = batch["audio"][item_idx]["array"],batch["audio"][item_idx]["sampling_rate"] start = 0 # Handle out of bound issues @@ -114,7 +114,7 @@ 
def __call__(self, batch): new_audio.append(mels) new_labels.append(label) - batch["audio"] = new_audio + batch["audio"] = np.concatenate(new_audio) batch["labels"] = np.array(new_labels, dtype=np.float32) return batch diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py index 5651e10..5951110 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py @@ -48,6 +48,7 @@ class WaveformPreprocessors(PreProcessorBase): def __init__( self, duration=5, + sr=None, augments: Augmentations = Augmentations(), ): """Defines a BuowMelSpectrogramPreprocessors. @@ -60,6 +61,7 @@ def __init__( """ self.duration = duration self.augments = augments + self.sr = sr # # Below parameter defaults from https://arxiv.org/pdf/2403.10380 pg 25 # self.n_fft = spectrogram_params.n_fft @@ -72,12 +74,13 @@ def __init__( def __call__(self, batch): """Process a batch of data from an AudioDataset.""" + # print("preprocessor", len(batch), len(batch["audio"]), len(batch["labels"])) new_audio = [] new_labels = [] for item_idx in range(len(batch["audio"])): label = batch["labels"][item_idx] try: - y, sr = librosa.load(path=batch["audio"][item_idx]["path"]) + y, sr = librosa.load(path=batch["audio"][item_idx]["path"], sr=self.sr) except Exception as e: print(e) print("File Likely is corrupted, moving on") @@ -103,6 +106,7 @@ def __call__(self, batch): batch["audio"] = new_audio batch["labels"] = np.array(new_labels, dtype=np.float32) + # print(len(batch["audio"]), len(batch["labels"])) return batch diff --git a/whoot_model_training/whoot_model_training/trainer.py b/whoot_model_training/whoot_model_training/trainer.py index 93f8702..e6a4cd8 100644 --- a/whoot_model_training/whoot_model_training/trainer.py +++ b/whoot_model_training/whoot_model_training/trainer.py @@ -17,7 +17,9 @@ from .metrics import WhootMutliClassMetrics from .dataset import AudioDataset from .models import Model - +import torch +import numpy as np +from tqdm import tqdm class WhootTrainingArguments(PyhaTrainingArguments): """Holds arguments use for training.""" @@ -103,3 +105,16 @@ def __init__( preprocessor, model.output_format.ignore_keys ) + def predict( + self, test_dataset: AudioDataset, ignore_keys = None, metric_key_prefix: str = "test" + ): + + test_dataloader = self.get_test_dataloader(test_dataset) + + preds = [] + for batch in tqdm(test_dataloader): + preds.append(self.model(self.model.input_format(**batch))["logits"].detach().cpu()) + + dataset = test_dataset.to_dict() + dataset["pred"] = torch.concat(preds).detach().numpy() + return dataset \ No newline at end of file From 0aee2957a2c85156a2103a1e2da4cf8830b2159b Mon Sep 17 00:00:00 2001 From: sean1572 Date: Wed, 15 Oct 2025 11:50:54 -0700 Subject: [PATCH 05/18] Reslove train.py conflicts --- whoot_model_training/train.py | 84 ++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 27 deletions(-) diff --git a/whoot_model_training/train.py b/whoot_model_training/train.py index a8658cf..1566cb5 100644 --- a/whoot_model_training/train.py +++ b/whoot_model_training/train.py @@ -25,6 +25,12 @@ MelModelInputPreprocessor ) from whoot_model_training.preprocessors.spectrogram_preprocessors import SpectrogramParams +from whoot_model_training.preprocessors.augmentations import ( + Gain, + PolarityInversion, + MixItUp, + ComposeAudioLabel +) # 
Uncomment for use with data augmentation # from pyha_analyzer.preprocessors import MixItUp, ComposeAudioLabel @@ -63,10 +69,17 @@ def train(config): config (dict): the config used for training. Defined in yaml file """ # Extract the dataset - ds = buowset_extractor( - metadata_csv=config["metadata_csv"], - parent_path=config["data_path"], - output_path=config["hf_cache_path"], + # ds = buowset_extractor( + # metadata_csv=config["metadata_csv"], + # parent_path=config["data_path"], + # output_path=config["hf_cache_path"], + # ) + + from whoot_model_training.whoot_model_training.data_extractor import xc_extractor + + ds = xc_extractor( + XC_dataset_json_path="/home/sean/whoot/data/san_diego_xc_aux/xc_meta_aux.json", + parent_path="/home/sean/whoot/data/san_diego_xc_aux/xeno-canto" ) # Create the model @@ -81,28 +94,45 @@ def train(config): # Preprocessors # Uncomment if doing work with data augmentation - # # Augmentations - # wav_augs = ComposeAudioLabel([ - # # AddBackgroundNoise( #We don't have background noise yet... - # # sounds_path="data_birdset/background_noise", - # # min_snr_db=10, - # # max_snr_db=30, - # # noise_transform=PolarityInversion(), - # # p=0.8 - # # ), - # Gain( - # min_gain_db = -12, - # max_gain_db = 12, - # p = 0.8 - # ), - # MixItUp( - # dataset_ref=ds["train"], - # min_snr_db=10, - # max_snr_db=30, - # noise_transform=PolarityInversion(), - # p=0.8 - # ) - # ]) + # Augmentations + wav_augs = ComposeAudioLabel([ + # AddBackgroundNoise( #We don't have background noise yet... + # sounds_path="data_birdset/background_noise", + # min_snr_db=10, + # max_snr_db=30, + # noise_transform=PolarityInversion(), + # p=0.8 + # ), + Gain( + min_gain_db = -12, + max_gain_db = 12, + p = 0.8 + ), + # MixItUp( + # dataset_ref=ds["train"], + # min_snr_db=10, + # max_snr_db=30, + # noise_transform=PolarityInversion(), + # p=0.8 + # ) + ]) + + spectrogram_params = SpectrogramParams() + # spectrogram_params = SpectrogramParams( + # n_mels = 224, + # hop_length = 286, + # ) + # """Dataclass for spectrogram Parameters. 
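# A rough librosa illustration of what the spectrogram parameters in this commented
# block (n_fft, hop_length, power, n_mels) control; the 3 s random clip and 32 kHz rate
# are assumptions, and the repo's own preprocessor may differ in detail:
import numpy as np
import librosa

y = np.random.uniform(-0.1, 0.1, 32_000 * 3).astype(np.float32)   # placeholder 3 s clip
mels = librosa.feature.melspectrogram(
    y=y, sr=32_000, n_fft=2048, hop_length=256, power=2.0, n_mels=256
)
print(mels.shape)   # (256, 376): n_mels rows, one frame per hop_length samples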
+ + # n_fft: (int) number of fft bins + # hop_length (int) skip count + # power: (float) usually 2 + # n_mels: (int) number of mel bins + # """ + # n_fft: int = 2048 + # hop_length: int = 256 + # power: float = 2.0 + # n_mels: int = 256 spectrogram_params = SpectrogramParams() # spectrogram_params = SpectrogramParams( @@ -160,7 +190,7 @@ def train(config): ) trainer.train() - model.save_pretrained("model_checkpoints/test") + model.save_pretrained("model_checkpoints/xc_aux") def init_env(config: dict): From b5510f2ad98617b5adb7319302117e5ff558dd1d Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 17 Oct 2025 14:13:46 -0700 Subject: [PATCH 06/18] Fix bugs with birdmae data loading --- pyproject.toml | 4 +-- .../data_extractor/xc_extractor.py | 10 ++++-- .../whoot_model_training/models/hf_models.py | 31 ++++++++++++------- .../whoot_model_training/models/model.py | 3 ++ 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3382876..e8b2174 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "pyyaml>=6.0.2", "scikit-learn>=1.7.0", "soundfile>=0.13.1", + "torchaudio>=2.8.0", "tqdm>=4.67.1", ] @@ -56,9 +57,6 @@ notebooks = [ "matplotlib>=3.10.6", "seaborn>=0.13.2", ] -birdnet = [ - "birdnet>=0.1.7", -] [packages.index] diff --git a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py index 17b63e8..60d2756 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py @@ -19,6 +19,7 @@ DatasetDict, ClassLabel, Sequence, + load_from_disk, ) from ..dataset import AudioDataset @@ -99,10 +100,14 @@ class XCParams(): def xc_extractor( XC_dataset_json_path, parent_path, + cache_path="data/san_diego_xc_aux/cache", params: XCParams = XCParams(), bad_file_path="data/xc_bad_file" ): - + if os.path.exists(cache_path): + return load_from_disk(cache_path) + + with open(XC_dataset_json_path, mode="r") as f: xc_recordings_paged = json.load(f) @@ -164,6 +169,7 @@ def xc_extractor( "test": test_val["test"]}) ) - # dataset.save_to_disk(output_path) + # os.makedirs(cache_path, exist_ok=True) + # dataset.save_to_disk(cache_path) return dataset diff --git a/whoot_model_training/whoot_model_training/models/hf_models.py b/whoot_model_training/whoot_model_training/models/hf_models.py index b01bf14..7ec538a 100644 --- a/whoot_model_training/whoot_model_training/models/hf_models.py +++ b/whoot_model_training/whoot_model_training/models/hf_models.py @@ -13,17 +13,19 @@ import timm from torch import nn +import torch +from contextlib import nullcontext from transformers import PretrainedConfig from .model import Model, ModelInput, ModelOutput, has_required_inputs -class HFInput(ModelInput): +class HFInput(): """Input for TimmModels. Specifies TimmModels needs labels and spectrograms that are Tensors """ - def __init__(self, labels, spectrogram=None, waveform=None, extractor_path="DBD-research-group/Bird-MAE-Base"): + def __init__(self, labels=None, spectrogram=None, waveform=None, extractor_path="DBD-research-group/Bird-MAE-Base"): """Creates TimmInputs. 
Args: @@ -32,16 +34,20 @@ def __init__(self, labels, spectrogram=None, waveform=None, extractor_path="DBD """ # print("fe works") - feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_path, trust_remote_code=True) + self.feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_path, trust_remote_code=True) + # self.feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_path, trust_remote_code=True) - mel_spectrogram = feature_extractor(waveform) + # mel_spectrogram = feature_extractor(waveform) - # # Can use inputs to verify correct shape for upstream model - # assert spectrogram.shape[1:] == (1, 100, 100) - super().__init__(labels, waveform=waveform, spectrogram=mel_spectrogram) - self.labels = labels - self.spectrogram = mel_spectrogram + # # # Can use inputs to verify correct shape for upstream model + # # assert spectrogram.shape[1:] == (1, 100, 100) + # self.labels = labels + # self.spectrogram = mel_spectrogram + def __call__(self, labels, spectrogram=None, waveform=None): + mel_spectrogram = self.feature_extractor(waveform) + return ModelInput(labels, waveform=None, spectrogram=mel_spectrogram) + class HFModelConfig(PretrainedConfig): """Config for Timm Model Zoo Models!""" def __init__( @@ -49,6 +55,7 @@ def __init__( path="DBD-research-group/Bird-MAE-Huge", num_classes=6, embeddings_size=1280, + freeze_backbone = True, **kwargs ): """Creates Config. @@ -61,6 +68,7 @@ def __init__( self.path = path self.num_classes = num_classes self.embeddings_size = embeddings_size + self.freeze_backbone = freeze_backbone super().__init__(**kwargs) @@ -83,7 +91,7 @@ def __init__( loss (any): custom loss function Default: BCEWithLogitsLoss """ super().__init__() - self.input_format = HFInput + self.input_format = ModelInput self.output_format = ModelOutput self.config = config assert config.num_classes > 0 @@ -123,7 +131,8 @@ def forward(self, x: HFInput) -> ModelOutput: (ModelOutput): The model output (logits), latent space representations (embeddings), loss and labels. """ - embed = self.backbone(x.spectrogram.to(self.device)).last_hidden_state + with torch.no_grad() if self.config.freeze_backbone else nullcontext(): + embed = self.backbone(x.spectrogram.to(self.device)).last_hidden_state logits = self.linear(embed) loss = self.loss(logits, x.labels) diff --git a/whoot_model_training/whoot_model_training/models/model.py b/whoot_model_training/whoot_model_training/models/model.py index aedc56f..cfab770 100644 --- a/whoot_model_training/whoot_model_training/models/model.py +++ b/whoot_model_training/whoot_model_training/models/model.py @@ -138,6 +138,9 @@ def __init__( "waveform": waveform, "spectrogram": spectrogram }) + self.labels = labels + self.waveform = waveform + self.spectrogram = spectrogram def items(self): """Get all items in dict. 
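The freeze_backbone change in this patch hinges on choosing a context manager at call
time; a standalone sketch of that pattern (the toy linear layer and sizes are invented,
not the Bird-MAE backbone) behaves like this:

    import torch
    from torch import nn
    from contextlib import nullcontext

    backbone = nn.Linear(8, 8)          # stand-in for the pretrained encoder
    freeze_backbone = True

    x = torch.randn(2, 8)
    with torch.no_grad() if freeze_backbone else nullcontext():
        feats = backbone(x)
    print(feats.requires_grad)          # False when frozen: no gradient reaches the backbone

With the flag set, only the classification head stacked on top receives gradient updates.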
From 3e95f110d75551a5156416ef104aaf285c365d82 Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 17 Oct 2025 16:02:46 -0700 Subject: [PATCH 07/18] Linted (round 1) --- .../data_extractor/raw_audio_extractor.py | 60 ++++++++++----- .../data_extractor/xc_extractor.py | 76 ++++++++++++++----- .../whoot_model_training/models/__init__.py | 8 +- .../models/few_shot_model.py | 62 ++++++++++----- .../whoot_model_training/models/hf_models.py | 75 +++++++++--------- .../preprocessors/augmentations.py | 14 +++- .../preprocessors/base_preprocessor.py | 5 -- .../preprocessors/inferance_wrap.py | 7 -- .../spectrogram_preprocessors.py | 5 +- .../preprocessors/waveform_preprocessors.py | 18 +++-- .../whoot_model_training/trainer.py | 27 +++++-- 11 files changed, 229 insertions(+), 128 deletions(-) delete mode 100644 whoot_model_training/whoot_model_training/preprocessors/inferance_wrap.py diff --git a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py index 443e4e3..ceafe77 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py @@ -143,6 +143,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.Str _FEATURE_TYPES[SubAudio.__name__] = SubAudio FeatureType = Union[FeatureType, SubAudio] + def get_empty_dict(): return { "audio": [], @@ -150,28 +151,38 @@ def get_empty_dict(): "labels": [], } + def get_array_chunks_from_memory( parent_folder, - chunk_length_sec = 5, - no_class_idx = 5, - output_path = "/data/manual_buowset" + chunk_length_sec=5, + no_class_idx=5, + output_path="/data/manual_buowset" ): - new_rows = get_empty_dict() _datasets = [] for root, dirs, files in tqdm(os.walk(parent_folder), desc="All Folders"): - for filename in tqdm(files, leave=False, desc=f"file in dir"): + for filename in tqdm(files, leave=False, desc="file in dir"): try: - if not filename.lower().endswith((".wav", ".mp3", ".flac", ".ogg", ".m4a")): + if not filename.lower().endswith( + (".wav", ".mp3", ".flac", ".ogg", ".m4a") + ): continue file_path = os.path.join(root, filename) try: clip_length = librosa.get_duration(path=file_path) - sr = librosa.get_samplerate(path=file_path) + # sr = librosa.get_samplerate(path=file_path) except BaseException as e: - print(file_path, "failed stat read", "continuing") + print(file_path, "failed stat read", "continuing", e) continue - for i in tqdm(range(0, int(floor(clip_length)), chunk_length_sec), leave=False, desc=f"{filename}"): + for i in tqdm( + range( + 0, + int(floor(clip_length)), + chunk_length_sec + ), + leave=False, + desc=f"{filename}" + ): new_rows["audio"].append({ "path": file_path, # "sampling_rate": sr, @@ -186,7 +197,9 @@ def get_array_chunks_from_memory( # This helps make sure stuff isn't loaded into memory # Hopefully - file_ds = Dataset.from_dict(new_rows).cast_column("audio", SubAudio()) + file_ds = Dataset.from_dict( + new_rows + ).cast_column("audio", SubAudio()) new_rows = get_empty_dict() _datasets.append(file_ds) # break #TODO REMOVE @@ -195,6 +208,7 @@ def get_array_chunks_from_memory( return concatenate_datasets(_datasets) + def one_hot_encode(row: dict, classes: list): """One hot Encodes a list of labels. 
@@ -207,18 +221,26 @@ def one_hot_encode(row: dict, classes: list): row["labels"] = np.array(one_hot, dtype=int) return row + def raw_audio_extractor( - audio_parent_folder:str = "", + audio_parent_folder: str = "", sr=32_000, - class_list = ["cluck", "coocoo", "twitter", "alarm", "chick begging", "no_buow"], - chunk_duration = -1, - output_folder = "" - + class_list=[ + "cluck", + "coocoo", + "twitter", + "alarm", + "chick begging", + "no_buow" + ], + chunk_duration=-1, + output_folder="" ): """Extracts raw, unlabeled data in the buowset format into an AudioDataset. Args: - audio_parent_folder (str): Path to the parent folder for all audio data. + audio_parent_folder (str): Path to the parent folder for all audio + data. Note its assumed the audio filepath in the csv is relative to parent_path sr (int): Sample Rate of the audio files Default: 32_000 @@ -227,7 +249,6 @@ def raw_audio_extractor( (AudioDataset): See dataset.py, AudioDatasets are consider the universal dataset for the training pipeline. """ - dataset = get_array_chunks_from_memory( parent_folder=audio_parent_folder, chunk_length_sec=chunk_duration, @@ -236,7 +257,9 @@ def raw_audio_extractor( # # # # Convert to a uniform one_hot encoding for classes dataset = dataset.class_encode_column("labels") multilabel_class_label = Sequence(ClassLabel(names=class_list)) - dataset = dataset.map(lambda row: one_hot_encode(row, class_list)).cast_column( + dataset = dataset.map( + lambda row: one_hot_encode(row, class_list) + ).cast_column( "labels", multilabel_class_label ) @@ -244,4 +267,3 @@ def raw_audio_extractor( DatasetDict({"train": dataset, "valid": dataset, "test": dataset}) ) return ds - diff --git a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py index 60d2756..4611734 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py @@ -13,7 +13,6 @@ import numpy as np from datasets import ( - load_dataset, Dataset, Audio, DatasetDict, @@ -26,22 +25,28 @@ import json import librosa + def filter_by_count(ds, col="en", threshold=10): + """Limit species list to species with some amount of species.""" count_by_species = Counter(ds[col]) - return ds.filter(lambda row: count_by_species[row] > threshold, input_columns=[col]) + return ds.filter( + lambda row: count_by_species[row] > threshold, + input_columns=[col] + ) def filter_xc_data(row: dict): - """ In personal experience, raw XC data is very messy + """In personal experience, raw XC data is very messy. + Some files get coruptted - This intention checks to see if loading files is possible for the frist place + This intention checks to see if loading files is + possible for the frist place """ - file_path = row["filepath"] try: # Heuristic, if we can load 3 seconds, file is probably okay # Prevents some files from taking forever - librosa.load(path=file_path, duration=3) + librosa.load(path=file_path, duration=3) return True except Exception as e: print(e, file_path) @@ -60,13 +65,21 @@ def one_hot_encode(row: dict, classes: list): row["labels"] = np.array(one_hot, dtype=float) return row + def convert_audio_to_flac(row, error_path="bad_files", col="audio"): + """Convert any audio to flac for better compression. 
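# For orientation, the Counter-based filter_by_count above behaves like this toy run
# (species names and counts are invented):
from collections import Counter
from datasets import Dataset

toy = Dataset.from_dict({"en": ["Great Horned Owl"] * 12 + ["Western Screech-Owl"] * 3})
counts = Counter(toy["en"])
kept = toy.filter(lambda species: counts[species] > 10, input_columns=["en"])
print(kept.num_rows)   # 12: the species with only 3 recordings is dropped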
+ + Args: + row: row from hugging face table + error_path: folder to dump broken files + col: column with audio path + """ file_path = row[col] flac_path = Path(file_path).parent / (Path(file_path).stem + ".flac") if os.path.exists(flac_path): row[col] = str(flac_path) if os.path.exists(file_path): - os.remove(file_path) # Remove origional file, we don't need it + os.remove(file_path) # Remove origional file, we don't need it return row try: wav_audio = AudioSegment.from_file(file_path) @@ -75,15 +88,23 @@ def convert_audio_to_flac(row, error_path="bad_files", col="audio"): if os.path.exists(file_path): os.makedirs(error_path, exist_ok=True) shutil.move(file_path, error_path) - # If quit halfway through processing, make sure we get rid of the bad file + # If quit halfway through processing, + # make sure we get rid of the bad file # if os.path.exists(flac_path): # os.remove(flac_path) - print("ERROR", "move to", os.path.join(error_path, Path(file_path).name), "ERR MSG:", e) + print( + "ERROR", + "move to", + os.path.join(error_path, Path(file_path).name), + "ERR MSG:", + e + ) row[col] = str(os.path.join(error_path, Path(file_path).name)) return row row[col] = str(flac_path) return row + @dataclass class XCParams(): """Parameters that describe ESC-50. @@ -97,6 +118,7 @@ class XCParams(): test_fold = 5 sample_rate = 44_100 + def xc_extractor( XC_dataset_json_path, parent_path, @@ -104,13 +126,18 @@ def xc_extractor( params: XCParams = XCParams(), bad_file_path="data/xc_bad_file" ): + """Extracts data collected from the XC downloader. + + XC_dataset_json_path: json outputted from XC downloader + parent_path: path to highest level audio file + cache_path: path to cache hugging + """ if os.path.exists(cache_path): return load_from_disk(cache_path) - with open(XC_dataset_json_path, mode="r") as f: xc_recordings_paged = json.load(f) - + xc_recordings = [] for page in xc_recordings_paged: xc_recordings.extend(page["recordings"]) @@ -130,17 +157,27 @@ def xc_extractor( dataset = dataset.add_column( "audio", [ - os.path.join(parent_path, file.replace("/", "_")) for file in dataset["file-name"] + os.path.join( + parent_path, + file.replace("/", "_") + ) for file in dataset["file-name"] ] ) # Fix file paths - dataset = dataset.map(convert_audio_to_flac, fn_kwargs={"error_path": bad_file_path}, num_proc=16) - dataset = dataset.filter(lambda x: not bad_file_path in x["audio"], num_proc=16) + dataset = dataset.map( + convert_audio_to_flac, + fn_kwargs={"error_path": bad_file_path}, + num_proc=16 + ) + dataset = dataset.filter( + lambda x: bad_file_path not in x["audio"], num_proc=16 + ) dataset = dataset.add_column("filepath", dataset["audio"]) - - - dataset = dataset.cast_column("audio", Audio(sampling_rate=params.sample_rate)) + dataset = dataset.cast_column( + "audio", + Audio(sampling_rate=params.sample_rate) + ) # TODO FIGURE OUT HOW TO DO SPLITS! 
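# Stripped of the error handling, the flac conversion above is a pydub round trip;
# the Xeno-Canto file name here is a made-up example:
from pydub import AudioSegment

seg = AudioSegment.from_file("XC123456.mp3")
seg.export("XC123456.flac", format="flac")   # lossless, and smaller than uncompressed wav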
# # Create splits of the data @@ -160,7 +197,10 @@ def xc_extractor( dataset = filter_by_count(dataset) train_test = dataset.train_test_split(0.2, stratify_by_column="en") - test_val = train_test["test"].train_test_split(0.2, stratify_by_column="en") + test_val = train_test["test"].train_test_split( + 0.2, + stratify_by_column="en" + ) dataset = AudioDataset( DatasetDict({ diff --git a/whoot_model_training/whoot_model_training/models/__init__.py b/whoot_model_training/whoot_model_training/models/__init__.py index 6bc07f3..da1f167 100644 --- a/whoot_model_training/whoot_model_training/models/__init__.py +++ b/whoot_model_training/whoot_model_training/models/__init__.py @@ -7,7 +7,11 @@ from .timm_model import TimmModel, TimmInputs, TimmModelConfig from .hf_models import HFModel, HFModelConfig, HFInput from .model import Model, ModelInput, ModelOutput -from .few_shot_model import PerchEmbeddingInput, PerchFewShotModel, FewShotModelConfig +from .few_shot_model import ( + PerchEmbeddingInput, + PerchFewShotModel, + FewShotModelConfig +) __all__ = [ "TimmModel", @@ -15,7 +19,7 @@ "TimmModelConfig", "HFModel", "HFModelConfig", - "HFInput" + "HFInput", "Model", "ModelInput", "ModelOutput", diff --git a/whoot_model_training/whoot_model_training/models/few_shot_model.py b/whoot_model_training/whoot_model_training/models/few_shot_model.py index a93d2fc..ab2bd11 100644 --- a/whoot_model_training/whoot_model_training/models/few_shot_model.py +++ b/whoot_model_training/whoot_model_training/models/few_shot_model.py @@ -1,51 +1,68 @@ """Build a few_shot_learning classifier. -Inspired by the work of -Jacuzzi, G., Olden, J.D., 2025. Few-shot transfer learning enables robust acoustic -monitoring of wildlife communities at the landscape scale. -Ecological Informatics 90, 103294. +Inspired by the work of +Jacuzzi, G., Olden, J.D., 2025. +Few-shot transfer learning enables robust acoustic +monitoring of wildlife communities at the landscape scale. +Ecological Informatics 90, 103294. doi.org/10.1016/j.ecoinf.2025.103294 -These models convert thier input into an embedding from a large audio model and +These models convert thier input into an embedding from a large audio model and do processing on top of that embedding """ -from .model import ModelInput, ModelOutput from torch import nn, Tensor from perch_hoplite.zoo import model_configs from .model import Model, ModelInput, ModelOutput, has_required_inputs from transformers import PretrainedConfig -## Common Classes class EmbeddingModel(): + """Wrapper for models which are only intended for embeddings.""" def embed(self): + """Get embedding.""" raise NotImplementedError() + class EmbeddingInput(ModelInput): + """Wrapper for ModelInputs that are embeddings.""" model = EmbeddingModel() embedding_size = 0 - def __init__(self, + def __init__( + self, labels, - waveform = None, - spectrogram = None): + waveform=None, + spectrogram=None + ): + """. 
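# The few-shot recipe described in the module docstring amounts to a small trainable
# head over frozen embeddings; a toy sketch with random stand-ins for Perch embeddings
# (1280-dim, six classes, all values invented):
import torch
from torch import nn

embeddings = torch.randn(4, 1280)        # would come from the frozen Perch encoder
labels = torch.zeros(4, 6)
labels[:, 5] = 1.0                       # e.g. every clip labelled "no_buow"

probe = nn.Linear(1280, 6)
loss = nn.BCEWithLogitsLoss()(probe(embeddings), labels)
loss.backward()                          # only the probe's weights receive gradients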
+ + Args: + labels: label + waveform: np array of sound + spectrogram: 2d array representing sound + """ super().__init__(labels, waveform, spectrogram) self["embedding"] = self.model.embed(waveform) -## Unique Models class PerchEmbeddings(EmbeddingModel): + """Wrapper for getting embeddings from perch.""" + model = model_configs.load_model_by_name('perch_8') - def embed(self, waveforms): + + def embed(self, embeddings): + """Return embeddings.""" # embeddings = [ # self.model.embed(waveform).embeddings[0] # for waveform in waveforms # ] - return waveforms + return embeddings + class PerchEmbeddingInput(EmbeddingInput): + """Wrapper for an input into a larger model from perch.""" model = PerchEmbeddings() embedding_size = 1280 @@ -60,12 +77,14 @@ def __init__( """Creates Config. Args: - + num_classes: how many species we want to detect """ self.num_classes = num_classes super().__init__(**kwargs) + class PerchFewShotModel(Model, nn.Module): + """Perch model intergration with pytorch.""" def __init__( self, config: FewShotModelConfig @@ -88,17 +107,22 @@ def __init__( self.config = config assert config.num_classes > 0 - # TODO BUILD MLP - self.linear = nn.Linear(self.input_format.embedding_size, config.num_classes) + self.linear = nn.Linear( + self.input_format.embedding_size, + config.num_classes + ) # TODO USE CUSTOM LOSS FOR FEW SHOW LEARNING self.loss = nn.BCEWithLogitsLoss() @has_required_inputs() def forward(self, x: PerchEmbeddingInput): + """Run model over x!""" # Use perch to create embeddings - embeddings = Tensor(x.model.model.embed(x["waveform"].cpu()).embeddings).to(x["waveform"].device) - + embeddings = Tensor( + x.model.model.embed(x["waveform"].cpu()).embeddings + ).to(x["waveform"].device) + logits = self.linear(embeddings).squeeze(1) loss = self.loss(logits, x["labels"]) @@ -108,5 +132,3 @@ def forward(self, x: PerchEmbeddingInput): loss=loss, labels=x["labels"] ) - - diff --git a/whoot_model_training/whoot_model_training/models/hf_models.py b/whoot_model_training/whoot_model_training/models/hf_models.py index 7ec538a..81f7e0e 100644 --- a/whoot_model_training/whoot_model_training/models/hf_models.py +++ b/whoot_model_training/whoot_model_training/models/hf_models.py @@ -1,17 +1,6 @@ -from transformers import AutoFeatureExtractor, AutoModel -import librosa - -"""Wrapper around the timms model zoo! - -See https://timm.fast.ai/ - -Timm model zoo good for computer vision models -Like CNNs, which are useful for spectrograms - -Great repo for models, but currently using this for demoing pipeline -""" +"""Wrapper around the hugging face model api!""" -import timm +from transformers import AutoFeatureExtractor, AutoModel from torch import nn import torch from contextlib import nullcontext @@ -21,41 +10,46 @@ class HFInput(): - """Input for TimmModels. + """Input for Hugging Face Models. Specifies TimmModels needs labels and spectrograms that are Tensors """ - def __init__(self, labels=None, spectrogram=None, waveform=None, extractor_path="DBD-research-group/Bird-MAE-Base"): + def __init__(self, + labels=None, + spectrogram=None, + waveform=None, + extractor_path="DBD-research-group/Bird-MAE-Base"): """Creates TimmInputs. 
Args: labels: the data's label for this batch - audio_data: some audio_data, basically this has a feature extractor for it + spectrogram: Legacy + waveform: Legacy + extractor_path: Path to hugging face preprocessor """ - - # print("fe works") - self.feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_path, trust_remote_code=True) - # self.feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_path, trust_remote_code=True) - - # mel_spectrogram = feature_extractor(waveform) - - # # # Can use inputs to verify correct shape for upstream model - # # assert spectrogram.shape[1:] == (1, 100, 100) - # self.labels = labels - # self.spectrogram = mel_spectrogram + self.feature_extractor = AutoFeatureExtractor.from_pretrained( + extractor_path, + trust_remote_code=True) + # TODO MAKE HFINPUT WORK WITH ITSELF def __call__(self, labels, spectrogram=None, waveform=None): + """Create some fake ModelInputs for HFModels. + + Slightly diffrent API for HFInput, when creating a input + Use the preprocessor from hugging face. + """ mel_spectrogram = self.feature_extractor(waveform) return ModelInput(labels, waveform=None, spectrogram=mel_spectrogram) - + + class HFModelConfig(PretrainedConfig): """Config for Timm Model Zoo Models!""" def __init__( self, - path="DBD-research-group/Bird-MAE-Huge", - num_classes=6, - embeddings_size=1280, - freeze_backbone = True, + path: str = "DBD-research-group/Bird-MAE-Huge", + num_classes: int = 6, + embeddings_size: int = 1280, + freeze_backbone: bool = True, **kwargs ): """Creates Config. @@ -64,6 +58,7 @@ def __init__( path (str): url to pull from hf model zoo num_classes (int): number of classes in dataset, for cls embeddings_size (int): size of output of model + freeze_backbone (bool): freeze the backbone of a model """ self.path = path self.num_classes = num_classes @@ -97,7 +92,10 @@ def __init__( assert config.num_classes > 0 # Deep learning CNN backbone - self.backbone = AutoModel.from_pretrained(config.path, trust_remote_code=True) + self.backbone = AutoModel.from_pretrained( + config.path, + trust_remote_code=True + ) # Unsure if 1000 is default for all timm models. Need to check this self.linear = nn.Linear(config.embeddings_size, config.num_classes) @@ -132,7 +130,9 @@ def forward(self, x: HFInput) -> ModelOutput: latent space representations (embeddings), loss and labels. """ with torch.no_grad() if self.config.freeze_backbone else nullcontext(): - embed = self.backbone(x.spectrogram.to(self.device)).last_hidden_state + embed = self.backbone( + x.spectrogram.to(self.device) + ).last_hidden_state logits = self.linear(embed) loss = self.loss(logits, x.labels) @@ -142,10 +142,3 @@ def forward(self, x: HFInput) -> ModelOutput: loss=loss, labels=x.labels ) - - - - - - - diff --git a/whoot_model_training/whoot_model_training/preprocessors/augmentations.py b/whoot_model_training/whoot_model_training/preprocessors/augmentations.py index 2c457cd..187b78a 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/augmentations.py +++ b/whoot_model_training/whoot_model_training/preprocessors/augmentations.py @@ -1,4 +1,5 @@ -"""Contains various data augementation techinques for bioacoustics +"""Contains various data augementation techinques for bioacoustics. 
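# A minimal waveform-level sketch of the audiomentations pieces re-exported by this
# module; the gain range mirrors train.py, while the clip and the 0.5 inversion
# probability are invented. (ComposeAudioLabel from pyha_analyzer additionally threads
# the label through, which plain Compose does not.)
import numpy as np
from audiomentations import Compose, Gain, PolarityInversion

augment = Compose([
    Gain(min_gain_db=-12, max_gain_db=12, p=0.8),
    PolarityInversion(p=0.5),
])
y = np.random.uniform(-0.3, 0.3, 32_000).astype(np.float32)   # 1 s placeholder clip
y_aug = augment(samples=y, sample_rate=32_000)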
+ Notes: relies heavily on the audiomentions library Basically combine augmentations with ComposeAudioLabel @@ -8,6 +9,15 @@ For Devs: To create a new augmentation, create a AudioLabelPreprocessor """ -from pyha_analyzer.preprocessors.augmentations import ComposeAudioLabel, MixItUp, AudioLabelPreprocessor +from pyha_analyzer.preprocessors.augmentations import ( + ComposeAudioLabel, MixItUp, AudioLabelPreprocessor +) from audiomentations import Gain, PolarityInversion +__all__ = [ + "ComposeAudioLabel", + "MixItUp", + "AudioLabelPreprocessor", + "Gain", + "PolarityInversion" +] diff --git a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py index 8d492fe..d337098 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py +++ b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py @@ -163,8 +163,3 @@ def __call__(self, batch: dict) -> ModelInput: labels=batch["labels"], waveform=batch["audio"] ) - - - - - diff --git a/whoot_model_training/whoot_model_training/preprocessors/inferance_wrap.py b/whoot_model_training/whoot_model_training/preprocessors/inferance_wrap.py deleted file mode 100644 index 9b46e1e..0000000 --- a/whoot_model_training/whoot_model_training/preprocessors/inferance_wrap.py +++ /dev/null @@ -1,7 +0,0 @@ -class MelModelInputPreprocessor(): - def __init__(self, preprocessor): - self.preprocessor = preprocessor - - def __call__(self, batch_input): - assert bat - pass \ No newline at end of file diff --git a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py index 089164e..6598718 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py @@ -77,7 +77,10 @@ def __call__(self, batch): new_labels = [] for item_idx in range(len(batch["audio"])): label = batch["labels"][item_idx] - y, sr = batch["audio"][item_idx]["array"],batch["audio"][item_idx]["sampling_rate"] + y, sr = ( + batch["audio"][item_idx]["array"], + batch["audio"][item_idx]["sampling_rate"] + ) start = 0 # Handle out of bound issues diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py index 5951110..a4e179b 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py @@ -6,7 +6,6 @@ import librosa import numpy as np -from torchvision import transforms from pyha_analyzer.preprocessors import PreProcessorBase @@ -56,6 +55,8 @@ def __init__( Args: duration (float): length of chunk of data to train on augments (Augmentations): An augmentation to apply to waveforms + sr (int/None): sample rate of audio to standize, + defaults to use file sr spectrogram_params (SpectrogramParams): config for spectrogram generation """ @@ -63,7 +64,8 @@ def __init__( self.augments = augments self.sr = sr - # # Below parameter defaults from https://arxiv.org/pdf/2403.10380 pg 25 + # # Below parameter defaults from + # # https://arxiv.org/pdf/2403.10380 pg 25 # self.n_fft = spectrogram_params.n_fft # self.hop_length = spectrogram_params.hop_length # self.power = spectrogram_params.power @@ -74,13 
+76,15 @@ def __init__( def __call__(self, batch): """Process a batch of data from an AudioDataset.""" - # print("preprocessor", len(batch), len(batch["audio"]), len(batch["labels"])) new_audio = [] - new_labels = [] + new_labels = [] for item_idx in range(len(batch["audio"])): label = batch["labels"][item_idx] try: - y, sr = librosa.load(path=batch["audio"][item_idx]["path"], sr=self.sr) + y, sr = librosa.load( + path=batch["audio"][item_idx]["path"], + sr=self.sr + ) except Exception as e: print(e) print("File Likely is corrupted, moving on") @@ -96,11 +100,11 @@ def __call__(self, batch): # Audio Based Augmentations if self.augments.audio is not None: y, label = self.augments.audio(y, sr, label) - + new_y = y[int(start * sr):end_sr] if (new_y.shape[-1] < int(sr * self.duration)): continue - + new_audio.append(new_y) new_labels.append(label) diff --git a/whoot_model_training/whoot_model_training/trainer.py b/whoot_model_training/whoot_model_training/trainer.py index e6a4cd8..35a22bf 100644 --- a/whoot_model_training/whoot_model_training/trainer.py +++ b/whoot_model_training/whoot_model_training/trainer.py @@ -18,9 +18,9 @@ from .dataset import AudioDataset from .models import Model import torch -import numpy as np from tqdm import tqdm + class WhootTrainingArguments(PyhaTrainingArguments): """Holds arguments use for training.""" def __init__(self, @@ -105,16 +105,31 @@ def __init__( preprocessor, model.output_format.ignore_keys ) + def predict( - self, test_dataset: AudioDataset, ignore_keys = None, metric_key_prefix: str = "test" - ): + self, + test_dataset: AudioDataset, + ignore_keys=None, + metric_key_prefix: str = "test" + ): + """Run Inferance on a given dataset. + Allows for getting predicted outputs to label a new dataset + Args: + test_dataset (AudioDataset): dataset to get preds from + This has labels but they are meaningless in this method + ignore_keys: N/A + metric_key_prefix: str = "test" + Returns: test_dataset with a new col: "pred" + """ test_dataloader = self.get_test_dataloader(test_dataset) - + preds = [] for batch in tqdm(test_dataloader): - preds.append(self.model(self.model.input_format(**batch))["logits"].detach().cpu()) + preds.append(self.model( + self.model.input_format(**batch) + )["logits"].detach().cpu()) dataset = test_dataset.to_dict() dataset["pred"] = torch.concat(preds).detach().numpy() - return dataset \ No newline at end of file + return dataset From d71134ff149c51ff96400439f7dd16458541d82f Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 17 Oct 2025 16:16:42 -0700 Subject: [PATCH 08/18] Remove perch model --- .../whoot_model_training/models/__init__.py | 10 +++++----- .../whoot_model_training/models/few_shot_model.py | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/whoot_model_training/whoot_model_training/models/__init__.py b/whoot_model_training/whoot_model_training/models/__init__.py index da1f167..c94c107 100644 --- a/whoot_model_training/whoot_model_training/models/__init__.py +++ b/whoot_model_training/whoot_model_training/models/__init__.py @@ -7,11 +7,11 @@ from .timm_model import TimmModel, TimmInputs, TimmModelConfig from .hf_models import HFModel, HFModelConfig, HFInput from .model import Model, ModelInput, ModelOutput -from .few_shot_model import ( - PerchEmbeddingInput, - PerchFewShotModel, - FewShotModelConfig -) +# from .few_shot_model import ( +# PerchEmbeddingInput, +# PerchFewShotModel, +# FewShotModelConfig +# ) __all__ = [ "TimmModel", diff --git 
a/whoot_model_training/whoot_model_training/models/few_shot_model.py b/whoot_model_training/whoot_model_training/models/few_shot_model.py index ab2bd11..0fa2234 100644 --- a/whoot_model_training/whoot_model_training/models/few_shot_model.py +++ b/whoot_model_training/whoot_model_training/models/few_shot_model.py @@ -50,7 +50,8 @@ def __init__( class PerchEmbeddings(EmbeddingModel): """Wrapper for getting embeddings from perch.""" - model = model_configs.load_model_by_name('perch_8') + # TODO FIX LINE TO BE LESS MEMORY HEAVY + # model = model_configs.load_model_by_name('perch_8') def embed(self, embeddings): """Return embeddings.""" From 6e6fe337a76251d3426808a33f7f8a5c56f6ff22 Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 31 Oct 2025 14:52:29 -0700 Subject: [PATCH 09/18] fix: adjust for nas data --- whoot_model_training/inferance.py | 6 ++---- .../whoot_model_training/data_extractor/xc_extractor.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/whoot_model_training/inferance.py b/whoot_model_training/inferance.py index c37f18e..4b4f0e2 100644 --- a/whoot_model_training/inferance.py +++ b/whoot_model_training/inferance.py @@ -12,13 +12,11 @@ config.yml should contain frequently changed hyperparameters """ -import os import argparse -import yaml from whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments -from whoot_model_training.data_extractor import buowset_extractor, raw_audio_extractor -from whoot_model_training.models import TimmModel, TimmInputs, TimmModelConfig +from whoot_model_training.data_extractor import raw_audio_extractor +from whoot_model_training.models import TimmModel, TimmInputs from whoot_model_training import CometMLLoggerSupplement from train import parse_config, init_env diff --git a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py index 4611734..ad78be5 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py @@ -168,10 +168,10 @@ def xc_extractor( dataset = dataset.map( convert_audio_to_flac, fn_kwargs={"error_path": bad_file_path}, - num_proc=16 + # num_proc=16 ) dataset = dataset.filter( - lambda x: bad_file_path not in x["audio"], num_proc=16 + lambda x: bad_file_path not in x["audio"], #num_proc=16 ) dataset = dataset.add_column("filepath", dataset["audio"]) dataset = dataset.cast_column( From f0f414b5268312116c9a17384c57237a3a54139d Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 31 Oct 2025 15:13:48 -0700 Subject: [PATCH 10/18] lint: finsh flake8 linting --- pyproject.toml | 1 + whoot_model_training/inferance.py | 17 +-- whoot_model_training/train.py | 61 +++++------ .../data_extractor/Jacuzzi_Olden_extractor.py | 101 ++++++++---------- .../data_extractor/__init__.py | 24 ++--- .../data_extractor/raw_audio_extractor.py | 99 +++++++++++------ .../data_extractor/xc_extractor.py | 2 +- .../models/few_shot_model.py | 2 +- 8 files changed, 162 insertions(+), 145 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e8b2174..2f57094 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">= 3.10.0, < 3.13.0" dependencies = [ "librosa>=0.10.2.post1", "numba==0.61.0", + "nvitop>=1.5.3", "pandas>=2.3.0", "pydub>=0.25.1", "python-dotenv>=1.1.1", diff --git a/whoot_model_training/inferance.py b/whoot_model_training/inferance.py index 4b4f0e2..fc09081 100644 --- 
a/whoot_model_training/inferance.py +++ b/whoot_model_training/inferance.py @@ -26,7 +26,12 @@ import pickle -def test(config, model_name=""): + +def test( + config, + model_name="", + audio_dir="/mnt/restorage/Audiomoth/Raw sound files/2024/RGCB/" +): """Highest level logic for inferance. Does the following: @@ -38,11 +43,12 @@ def test(config, model_name=""): Args: config (dict): the config used for training. Defined in yaml file - TODO + model_name (str): path to model checkpoint to use + audio_dir (str): path to unlabeled data """ # Extract a new dataset ds = raw_audio_extractor( - audio_parent_folder="/mnt/restorage/Audiomoth/Raw sound files/2024/RGCB/", + audio_parent_folder=audio_dir, output_folder="data/manual_buowset", chunk_duration=3 ) @@ -64,12 +70,10 @@ def test(config, model_name=""): # ds["valid"].set_transform(preprocessor) # ds["test"].set_transform(preprocessor) - model_name = "efficientnet_b1" run_name = f"buowset1.1_{model_name}_ATTEMPT_TO_STUDY_NEW_DATA" # trainer = WhootTrainer._load_from_checkpoint(model_name) - # Run training training_args = WhootTrainingArguments( run_name=run_name, @@ -105,11 +109,12 @@ def test(config, model_name=""): # trainer.evaluate(ds["test"], metric_key_prefix="test") # trainer.evaluate(ds["valid"], metric_key_prefix="valid") + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Input config path") parser.add_argument("config", type=str, help="Path to config.yml") parser.add_argument( - "--model_name", + "--model_name", required=False, help="path to weights or hugging face repo id", default="/home/sean/whoot/checkpoint-4985") diff --git a/whoot_model_training/train.py b/whoot_model_training/train.py index 1566cb5..7b61d53 100644 --- a/whoot_model_training/train.py +++ b/whoot_model_training/train.py @@ -17,19 +17,15 @@ import yaml from whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments -from whoot_model_training.data_extractor import buowset_extractor +from whoot_model_training.data_extractor import xc_extractor from whoot_model_training.models import TimmModel, TimmInputs, TimmModelConfig from whoot_model_training import CometMLLoggerSupplement from whoot_model_training.preprocessors import ( MelModelInputPreprocessor ) -from whoot_model_training.preprocessors.spectrogram_preprocessors import SpectrogramParams -from whoot_model_training.preprocessors.augmentations import ( - Gain, - PolarityInversion, - MixItUp, - ComposeAudioLabel +from whoot_model_training.preprocessors.spectrogram_preprocessors import ( + SpectrogramParams ) # Uncomment for use with data augmentation @@ -75,10 +71,9 @@ def train(config): # output_path=config["hf_cache_path"], # ) - from whoot_model_training.whoot_model_training.data_extractor import xc_extractor - + csv_path = "/home/sean/whoot/data/san_diego_xc_aux/xc_meta_aux.json" ds = xc_extractor( - XC_dataset_json_path="/home/sean/whoot/data/san_diego_xc_aux/xc_meta_aux.json", + XC_dataset_json_path=csv_path, parent_path="/home/sean/whoot/data/san_diego_xc_aux/xeno-canto" ) @@ -93,29 +88,29 @@ def train(config): # Preprocessors - # Uncomment if doing work with data augmentation - # Augmentations - wav_augs = ComposeAudioLabel([ - # AddBackgroundNoise( #We don't have background noise yet... 
- # sounds_path="data_birdset/background_noise", - # min_snr_db=10, - # max_snr_db=30, - # noise_transform=PolarityInversion(), - # p=0.8 - # ), - Gain( - min_gain_db = -12, - max_gain_db = 12, - p = 0.8 - ), - # MixItUp( - # dataset_ref=ds["train"], - # min_snr_db=10, - # max_snr_db=30, - # noise_transform=PolarityInversion(), - # p=0.8 - # ) - ]) + # # Uncomment if doing work with data augmentation + # # Augmentations + # wav_augs = ComposeAudioLabel([ + # # AddBackgroundNoise( #We don't have background noise yet... + # # sounds_path="data_birdset/background_noise", + # # min_snr_db=10, + # # max_snr_db=30, + # # noise_transform=PolarityInversion(), + # # p=0.8 + # # ), + # Gain( + # min_gain_db=-12, + # max_gain_db=12, + # p = 0.8 + # ), + # # MixItUp( + # # dataset_ref=ds["train"], + # # min_snr_db=10, + # # max_snr_db=30, + # # noise_transform=PolarityInversion(), + # # p=0.8 + # # ) + # ]) spectrogram_params = SpectrogramParams() # spectrogram_params = SpectrogramParams( diff --git a/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py index d79fea3..fe16cc3 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py @@ -3,60 +3,47 @@ See data_downloader/xc.py """ -import os -from dataclasses import dataclass - -import numpy as np -from datasets import ( - load_dataset, - Dataset, - Audio, - DatasetDict, - ClassLabel, - Sequence, -) -from ..dataset import AudioDataset - -import json -import pandas as pd - - -def one_hot_encode(row: dict, classes: list): - """One hot Encodes a list of labels. - - Args: - row (dict): row of data in a dataset containing a labels column - classes: a list of classes - """ - one_hot = np.zeroes(len(classes)) - one_hot[row["labels"]] = 1 - row["labels"] = np.array(one_hot, dtype=float) - return row - -def Jacuzzi_Olden_Extractor( - root_path -): - audio_path = f"{root_path}/training/audio" - train_df = pd.read_csv(f"{root_path}/training/training_data_annotations.csv") - train_df["labels"] = train_df["labels"].str.split(",") - train_df["file_path"] = train_df["audio_subdir"].apply( - lambda folder: f"{audio_path}/{folder}/" - ) + train_df["file"].apply(lambda path: path + ".wav") - - test_df = pd.read_csv(f"{root_path}/test/test_data_annotations.csv") - test_df["labels"] = test_df["labels"].str.split(",") - test_df["file"] = test_df["file"].str.findall( - r"-0.\d+_([\w.]+).wav").apply(lambda x: x[0]) - test_df["file_path"] = test_df["focal_class"].apply( - lambda folder: f"{audio_path}/{folder}/" - ) + test_df["file"].apply(lambda path: path + ".wav") - - return train_df, test_df - - # TODO - # Convert to AudioDataset - # Convert Labels to right format - # Convert audio type - # Done - - +# import numpy as np +# from ..dataset import AudioDataset + +# import json +# import pandas as pd + + +# def one_hot_encode(row: dict, classes: list): +# """One hot Encodes a list of labels. 
+ +# Args: +# row (dict): row of data in a dataset containing a labels column +# classes: a list of classes +# """ +# one_hot = np.zeroes(len(classes)) +# one_hot[row["labels"]] = 1 +# row["labels"] = np.array(one_hot, dtype=float) +# return row + +# def Jacuzzi_Olden_Extractor(root_path): +# audio_path = f"{root_path}/training/audio" +# train_df = pd.read_csv( +# f"{root_path}/training/training_data_annotations.csv" +# ) +# train_df["labels"] = train_df["labels"].str.split(",") +# train_df["file_path"] = train_df["audio_subdir"].apply( +# lambda folder: f"{audio_path}/{folder}/" +# ) + train_df["file"].apply(lambda path: path + ".wav") + +# test_df = pd.read_csv(f"{root_path}/test/test_data_annotations.csv") +# test_df["labels"] = test_df["labels"].str.split(",") +# test_df["file"] = test_df["file"].str.findall( +# r"-0.\d+_([\w.]+).wav").apply(lambda x: x[0]) +# test_df["file_path"] = test_df["focal_class"].apply( +# lambda folder: f"{audio_path}/{folder}/" +# ) + test_df["file"].apply(lambda path: path + ".wav") + +# return train_df, test_df + +# # TODO +# # Convert to AudioDataset +# # Convert Labels to right format +# # Convert audio type +# # Done diff --git a/whoot_model_training/whoot_model_training/data_extractor/__init__.py b/whoot_model_training/whoot_model_training/data_extractor/__init__.py index 9b3ef7f..8b06cb1 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/__init__.py +++ b/whoot_model_training/whoot_model_training/data_extractor/__init__.py @@ -12,21 +12,17 @@ from .Jacuzzi_Olden_extractor import Jacuzzi_Olden_Extractor from .xc_extractor import xc_extractor -__all__ = ["buowset_extractor", "buowset_binary_extractor", "esc50_extractor", "Jacuzzi_Olden_Extractor", "xc_extractor", "raw_audio_extractor"] +__all__ = [ + "buowset_extractor", + "buowset_binary_extractor", + "esc50_extractor", + "Jacuzzi_Olden_Extractor", + "xc_extractor", + "raw_audio_extractor" +] + def concat_dataset(datasetA, datasetB): for split in datasetA.keys(): pass - - #TODO FIGURE OUT HOW TO SAFETLY COMBINE TWO DATASETS - - # labels - # this is tricky, you need to check class names for union, then - # Apply annotations accordingly - # maybe use a dict to handle classes in both datasets - - # Audio - # should be able to merge - - # Metadata - # Consider dropping all non-required columns, will make merge easier + # TODO FIGURE OUT HOW TO SAFETLY COMBINE TWO DATASETS diff --git a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py index ceafe77..6139df8 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py @@ -2,16 +2,15 @@ Extractor for general, typically unlabeled soundscape recordings -Fits as much as possible to the AudioDataset standard but +Fits as much as possible to the AudioDataset standard but NOT INTENDED FOR TRAINING Rather just a placeholder to help inferance work """ -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Union import os import numpy as np from datasets import ( - load_dataset, Audio, concatenate_datasets, DatasetDict, @@ -24,11 +23,11 @@ from math import floor from tqdm import tqdm import pyarrow as pa +from datasets.features.features import _FEATURE_TYPES, FeatureType from ..dataset import AudioDataset - class SubAudio(Audio): """Extends Audio to take a chunks of data. 
@@ -36,7 +35,7 @@ class SubAudio(Audio): https://github.com/huggingface/datasets/blob/5dc1a179783dff868b0547c8486268cfaea1ea1f/src/datasets/features/audio.py#L24 The Audio Column of a HuggingFace dataset - handles loading in data from a given file + handles loading in data from a given file What is nice is it streams data: it doesn't get loaded into memory until it is needed via the path @@ -45,40 +44,52 @@ class SubAudio(Audio): We would need to load it as an array instead of a path And it gets loaded into memory. Huge issue with large audio datasets. - By default HF doesn't support chunking, so this class should handle chunking + By default HF doesn't support chunking, + so this class should handle chunking During streaming rather than during dataset creation - You can use it the same way you might with the Audio class. In fact, with normal processing + You can use it the same way you might with the Audio class. + In fact, with normal processing it handles the same way! - To use the chunking feature, create a Audio row with the following parameters + To use the chunking feature, create a Audio row with the + following parameters - path: as is with Audio - sampling_rate: as is with Audio - - offset: NEW, offset in seconds of when to start taking audio data - - duration: NEW, duration from offset in seconds for how much data to collect + - offset: NEW, offset in seconds of when to start taking + audio data + - duration: NEW, duration from offset in seconds for + how much data to collect - You need both offset and duration to load in the chunk, otherwise it will load the full file. + You need both offset and duration to load in the chunk, + otherwise it will load the full file. """ pa_type: ClassVar[Any] = pa.struct({ - "bytes": pa.binary(), + "bytes": pa.binary(), "path": pa.string(), "offset": pa.int64(), "duration": pa.int64() }) def __call__(self): + """Get type.""" return self.pa_type def encode_example(self, value) -> dict: + """Encode audio data to raw.""" if ( isinstance(value, dict) and value.get("offset") and value.get("duration") - and value.get("path") is not None + and value.get("path") is not None and os.path.isfile(value["path"]) ): - y, sr = librosa.load(path = value["path"], offset=value["offset"], duration=value["duration"]) + y, sr = librosa.load( + path=value["path"], + offset=value["offset"], + duration=value["duration"] + ) value["array"] = y value["sampling_rate"] = sr encoded = super().encode_example(value) @@ -88,20 +99,24 @@ def encode_example(self, value) -> dict: return encoded return super().encode_example(value) - def decode_example(self, value, token_per_repo_id=None) -> dict: - # print("d4ecode", value) + def decode_example(self, value, token_per_repo_id=None) -> dict: + """Decode raw data into audio info and array.""" if ( isinstance(value, dict) and "offset" in value and "duration" in value - and value.get("bytes") is None - and value.get("path") is not None + and value.get("bytes") is None + and value.get("path") is not None and os.path.isfile(value["path"]) ): - y, sr = librosa.load(path = value["path"], offset=value["offset"], duration=value["duration"]) + y, sr = librosa.load( + path=value["path"], + offset=value["offset"], + duration=value["duration"] + ) return { - "path": value["path"], - "array": y, + "path": value["path"], + "array": y, "sampling_rate": sr, "offset": value["offset"], "duration": value["duration"]} @@ -109,16 +124,25 @@ def decode_example(self, value, token_per_repo_id=None) -> dict: isinstance(value, dict) and value.get("offset") 
and value.get("duration") - and value.get("bytes") is not None + and value.get("bytes") is not None ): - decoded = super().decode_example(value, token_per_repo_id=token_per_repo_id) - decoded["offset"] = value["offset"] - decoded["duration"] = value["duration"] + decoded = super().decode_example( + value, + token_per_repo_id=token_per_repo_id + ) + decoded["offset"] = value["offset"] + decoded["duration"] = value["duration"] return decoded - return super().decode_example(value, token_per_repo_id=token_per_repo_id) - - def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.StructArray: - # print("cast_storage real") + return super().decode_example( + value, + token_per_repo_id=token_per_repo_id + ) + + def cast_storage( + self, + storage: Union[pa.StringArray, pa.StructArray] + ) -> pa.StructArray: + """Cast a hugging face dataset column as the data type.""" if pa.types.is_struct(storage.type): if storage.type.get_field_index("bytes") >= 0: bytes_array = storage.field("bytes") @@ -135,17 +159,25 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.Str if storage.type.get_field_index("duration") >= 0: duration_array = storage.field("duration") else: - duration_array = pa.array([None] * len(storage), type=pa.int64()) - storage = pa.StructArray.from_arrays([bytes_array, path_array, offset_array, duration_array], ["bytes", "path", "offset", "duration"], mask=storage.is_null()) + duration_array = pa.array( + [None] * len(storage), + type=pa.int64() + ) + storage = pa.StructArray.from_arrays( + [bytes_array, path_array, offset_array, duration_array], + ["bytes", "path", "offset", "duration"], + mask=storage.is_null() + ) return table.array_cast(storage, self.pa_type) -from datasets.features.features import _FEATURE_TYPES, FeatureType + _FEATURE_TYPES[SubAudio.__name__] = SubAudio FeatureType = Union[FeatureType, SubAudio] def get_empty_dict(): - return { + """Helper to make a new row.""" + return { "audio": [], "file_path": [], "labels": [], @@ -158,6 +190,7 @@ def get_array_chunks_from_memory( no_class_idx=5, output_path="/data/manual_buowset" ): + """Split audio data into chunks and save each as SubAudio data.""" new_rows = get_empty_dict() _datasets = [] for root, dirs, files in tqdm(os.walk(parent_folder), desc="All Folders"): diff --git a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py index ad78be5..95a2d73 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py @@ -171,7 +171,7 @@ def xc_extractor( # num_proc=16 ) dataset = dataset.filter( - lambda x: bad_file_path not in x["audio"], #num_proc=16 + lambda x: bad_file_path not in x["audio"], ) dataset = dataset.add_column("filepath", dataset["audio"]) dataset = dataset.cast_column( diff --git a/whoot_model_training/whoot_model_training/models/few_shot_model.py b/whoot_model_training/whoot_model_training/models/few_shot_model.py index 0fa2234..89a4eb0 100644 --- a/whoot_model_training/whoot_model_training/models/few_shot_model.py +++ b/whoot_model_training/whoot_model_training/models/few_shot_model.py @@ -12,7 +12,7 @@ """ from torch import nn, Tensor -from perch_hoplite.zoo import model_configs +# from perch_hoplite.zoo import model_configs from .model import Model, ModelInput, ModelOutput, has_required_inputs from transformers import PretrainedConfig From 
1cd04bd6e6afbb1bc67a2bbed69e917af9f049d7 Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 12 Dec 2025 15:31:10 -0800 Subject: [PATCH 11/18] Clean up few-shot experiment branch --- .gitignore | 2 + whoot_model_training/inferance.py | 9 ++-- .../data_extractor/Jacuzzi_Olden_extractor.py | 49 ------------------- .../data_extractor/__init__.py | 8 --- .../data_extractor/raw_audio_extractor.py | 20 ++++---- .../data_extractor/xc_extractor.py | 30 ++++++------ .../whoot_model_training/models/__init__.py | 4 +- .../models/few_shot_model.py | 16 +++--- .../spectrogram_preprocessors.py | 20 +++++--- .../preprocessors/waveform_preprocessors.py | 5 +- .../whoot_model_training/trainer.py | 5 +- 11 files changed, 61 insertions(+), 107 deletions(-) delete mode 100644 whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py diff --git a/.gitignore b/.gitignore index d5ced77..53aff3d 100644 --- a/.gitignore +++ b/.gitignore @@ -227,3 +227,5 @@ settings.json whoot_model_training/configs !whoot_model_training/configs/config.yml *.csv +*.ipynb +*.json \ No newline at end of file diff --git a/whoot_model_training/inferance.py b/whoot_model_training/inferance.py index fc09081..137cdfd 100644 --- a/whoot_model_training/inferance.py +++ b/whoot_model_training/inferance.py @@ -13,18 +13,21 @@ config.yml should contain frequently changed hyperparameters """ import argparse +import pickle + +from train import parse_config, init_env + +from whoot_model_training.preprocessors.base_preprocessor import WaveformInputPreprocessor from whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments from whoot_model_training.data_extractor import raw_audio_extractor from whoot_model_training.models import TimmModel, TimmInputs from whoot_model_training import CometMLLoggerSupplement -from train import parse_config, init_env - from whoot_model_training.preprocessors import ( MelModelInputPreprocessor ) -import pickle + def test( diff --git a/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py deleted file mode 100644 index fe16cc3..0000000 --- a/whoot_model_training/whoot_model_training/data_extractor/Jacuzzi_Olden_extractor.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Ceates Dataset from the Xeno-Canto Data Downlaoder tool. - -See data_downloader/xc.py -""" - -# import numpy as np -# from ..dataset import AudioDataset - -# import json -# import pandas as pd - - -# def one_hot_encode(row: dict, classes: list): -# """One hot Encodes a list of labels. 
- -# Args: -# row (dict): row of data in a dataset containing a labels column -# classes: a list of classes -# """ -# one_hot = np.zeroes(len(classes)) -# one_hot[row["labels"]] = 1 -# row["labels"] = np.array(one_hot, dtype=float) -# return row - -# def Jacuzzi_Olden_Extractor(root_path): -# audio_path = f"{root_path}/training/audio" -# train_df = pd.read_csv( -# f"{root_path}/training/training_data_annotations.csv" -# ) -# train_df["labels"] = train_df["labels"].str.split(",") -# train_df["file_path"] = train_df["audio_subdir"].apply( -# lambda folder: f"{audio_path}/{folder}/" -# ) + train_df["file"].apply(lambda path: path + ".wav") - -# test_df = pd.read_csv(f"{root_path}/test/test_data_annotations.csv") -# test_df["labels"] = test_df["labels"].str.split(",") -# test_df["file"] = test_df["file"].str.findall( -# r"-0.\d+_([\w.]+).wav").apply(lambda x: x[0]) -# test_df["file_path"] = test_df["focal_class"].apply( -# lambda folder: f"{audio_path}/{folder}/" -# ) + test_df["file"].apply(lambda path: path + ".wav") - -# return train_df, test_df - -# # TODO -# # Convert to AudioDataset -# # Convert Labels to right format -# # Convert audio type -# # Done diff --git a/whoot_model_training/whoot_model_training/data_extractor/__init__.py b/whoot_model_training/whoot_model_training/data_extractor/__init__.py index 8b06cb1..b9db3ce 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/__init__.py +++ b/whoot_model_training/whoot_model_training/data_extractor/__init__.py @@ -9,20 +9,12 @@ ) from .esc50_extractor import esc50_extractor from .raw_audio_extractor import raw_audio_extractor -from .Jacuzzi_Olden_extractor import Jacuzzi_Olden_Extractor from .xc_extractor import xc_extractor __all__ = [ "buowset_extractor", "buowset_binary_extractor", "esc50_extractor", - "Jacuzzi_Olden_Extractor", "xc_extractor", "raw_audio_extractor" ] - - -def concat_dataset(datasetA, datasetB): - for split in datasetA.keys(): - pass - # TODO FIGURE OUT HOW TO SAFETLY COMBINE TWO DATASETS diff --git a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py index 6139df8..5e81409 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py @@ -7,6 +7,7 @@ Rather just a placeholder to help inferance work """ +from math import floor from typing import Any, ClassVar, Union import os import numpy as np @@ -20,7 +21,6 @@ table ) import librosa -from math import floor from tqdm import tqdm import pyarrow as pa from datasets.features.features import _FEATURE_TYPES, FeatureType @@ -101,10 +101,14 @@ def encode_example(self, value) -> dict: def decode_example(self, value, token_per_repo_id=None) -> dict: """Decode raw data into audio info and array.""" - if ( + correct_dict = ( isinstance(value, dict) and "offset" in value and "duration" in value + ) + + if ( + correct_dict and value.get("bytes") is None and value.get("path") is not None and os.path.isfile(value["path"]) @@ -120,7 +124,7 @@ def decode_example(self, value, token_per_repo_id=None) -> dict: "sampling_rate": sr, "offset": value["offset"], "duration": value["duration"]} - elif ( + if ( isinstance(value, dict) and value.get("offset") and value.get("duration") @@ -187,13 +191,12 @@ def get_empty_dict(): def get_array_chunks_from_memory( parent_folder, chunk_length_sec=5, - no_class_idx=5, - output_path="/data/manual_buowset" + 
no_class_idx=5 ): """Split audio data into chunks and save each as SubAudio data.""" new_rows = get_empty_dict() _datasets = [] - for root, dirs, files in tqdm(os.walk(parent_folder), desc="All Folders"): + for root, _, files in tqdm(os.walk(parent_folder), desc="All Folders"): for filename in tqdm(files, leave=False, desc="file in dir"): try: if not filename.lower().endswith( @@ -224,7 +227,6 @@ def get_array_chunks_from_memory( }) new_rows["file_path"].append(filename) new_rows["labels"].append(no_class_idx) - # break #TODO REMOVE except BaseException as e: print(e) @@ -235,9 +237,6 @@ def get_array_chunks_from_memory( ).cast_column("audio", SubAudio()) new_rows = get_empty_dict() _datasets.append(file_ds) - # break #TODO REMOVE - # if len(_datasets) > 1: #TODO REMOVE - # break #TODO REMOVE return concatenate_datasets(_datasets) @@ -267,7 +266,6 @@ def raw_audio_extractor( "no_buow" ], chunk_duration=-1, - output_folder="" ): """Extracts raw, unlabeled data in the buowset format into an AudioDataset. diff --git a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py index 95a2d73..62f665d 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py @@ -9,7 +9,8 @@ from dataclasses import dataclass from collections import Counter from pydub import AudioSegment - +import json +import librosa import numpy as np from datasets import ( @@ -22,9 +23,6 @@ ) from ..dataset import AudioDataset -import json -import librosa - def filter_by_count(ds, col="en", threshold=10): """Limit species list to species with some amount of species.""" @@ -76,6 +74,7 @@ def convert_audio_to_flac(row, error_path="bad_files", col="audio"): """ file_path = row[col] flac_path = Path(file_path).parent / (Path(file_path).stem + ".flac") + # print(file_path, flac_path) if os.path.exists(flac_path): row[col] = str(flac_path) if os.path.exists(file_path): @@ -164,32 +163,33 @@ def xc_extractor( ] ) + + # Only accept less than 10 min long clips + # Longer clips seem to courrpt more easily... + # Format is "#:##"" hence length 4 + dataset = dataset.filter( + lambda x: len(x["length"]) == 4 + ) + # Fix file paths dataset = dataset.map( convert_audio_to_flac, fn_kwargs={"error_path": bad_file_path}, # num_proc=16 ) + + + dataset = dataset.filter( lambda x: bad_file_path not in x["audio"], ) + dataset = dataset.add_column("filepath", dataset["audio"]) dataset = dataset.cast_column( "audio", Audio(sampling_rate=params.sample_rate) ) - # TODO FIGURE OUT HOW TO DO SPLITS! 
- # # Create splits of the data - # test_ds = dataset.filter(lambda x: x["fold"] == params.test_fold) - # valid_ds = dataset.filter(lambda x: x["fold"] == params.validation_fold) - # train_ds = dataset.filter( - # lambda x: ( - # x["fold"] != params.test_fold - # and x["fold"] != params.validation_fold - # ) - # ) - dataset = dataset.cast_column( "en", ClassLabel(names=list(set(dataset["en"]))) ) diff --git a/whoot_model_training/whoot_model_training/models/__init__.py b/whoot_model_training/whoot_model_training/models/__init__.py index c94c107..55aca50 100644 --- a/whoot_model_training/whoot_model_training/models/__init__.py +++ b/whoot_model_training/whoot_model_training/models/__init__.py @@ -23,7 +23,7 @@ "Model", "ModelInput", "ModelOutput", - "PerchEmbeddingInput", - "PerchFewShotModel", + # "PerchEmbeddingInput", + # "PerchFewShotModel", "FewShotModelConfig" ] diff --git a/whoot_model_training/whoot_model_training/models/few_shot_model.py b/whoot_model_training/whoot_model_training/models/few_shot_model.py index 89a4eb0..33a4c06 100644 --- a/whoot_model_training/whoot_model_training/models/few_shot_model.py +++ b/whoot_model_training/whoot_model_training/models/few_shot_model.py @@ -46,19 +46,22 @@ def __init__( self["embedding"] = self.model.embed(waveform) +# Global variable fore PerchEmbeddings +perch_model = None class PerchEmbeddings(EmbeddingModel): """Wrapper for getting embeddings from perch.""" - # TODO FIX LINE TO BE LESS MEMORY HEAVY - # model = model_configs.load_model_by_name('perch_8') + # Warning, was running into issues with memory here + # Early attempts recreated model + # Hoping using global var only loads it in once + if perch_model is None: + perch_model = model_configs.load_model_by_name('perch_8') + + model = perch_model def embed(self, embeddings): """Return embeddings.""" - # embeddings = [ - # self.model.embed(waveform).embeddings[0] - # for waveform in waveforms - # ] return embeddings @@ -113,7 +116,6 @@ def __init__( config.num_classes ) - # TODO USE CUSTOM LOSS FOR FEW SHOW LEARNING self.loss = nn.BCEWithLogitsLoss() @has_required_inputs() diff --git a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py index 6598718..5259c98 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py @@ -94,17 +94,21 @@ def __call__(self, batch): pillow_transforms = transforms.ToPILImage() + S = librosa.feature.melspectrogram( + y=y[int(start * sr):end_sr], + sr=sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + power=self.power, + n_mels=self.n_mels, + ) + pcen_S = librosa.pcen(S * (2**31)) + + mels = ( np.array( pillow_transforms( - librosa.feature.melspectrogram( - y=y[int(start * sr):end_sr], - sr=sr, - n_fft=self.n_fft, - hop_length=self.hop_length, - power=self.power, - n_mels=self.n_mels, - ) + pcen_S ), np.float32, )[np.newaxis, ::] diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py index a4e179b..03f1203 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py @@ -81,6 +81,9 @@ def __call__(self, batch): for item_idx in range(len(batch["audio"])): label = 
batch["labels"][item_idx] try: + #TEMP FIX TO PREVENT LONG AUDIO RECORDINGS BEING BAD!!! + if librosa.get_duration(path=batch["audio"][item_idx]["path"]) > 2 * 60: + continue y, sr = librosa.load( path=batch["audio"][item_idx]["path"], sr=self.sr @@ -102,7 +105,7 @@ def __call__(self, batch): y, label = self.augments.audio(y, sr, label) new_y = y[int(start * sr):end_sr] - if (new_y.shape[-1] < int(sr * self.duration)): + if new_y.shape[-1] < int(sr * self.duration): continue new_audio.append(new_y) diff --git a/whoot_model_training/whoot_model_training/trainer.py b/whoot_model_training/whoot_model_training/trainer.py index 35a22bf..97c7271 100644 --- a/whoot_model_training/whoot_model_training/trainer.py +++ b/whoot_model_training/whoot_model_training/trainer.py @@ -10,15 +10,14 @@ from datetime import datetime import os - +import torch +from tqdm import tqdm from pyha_analyzer import PyhaTrainingArguments from pyha_analyzer import PyhaTrainer from .metrics import WhootMutliClassMetrics from .dataset import AudioDataset from .models import Model -import torch -from tqdm import tqdm class WhootTrainingArguments(PyhaTrainingArguments): From afca698a8e260672ba543dd421e3d4d3dfa349fe Mon Sep 17 00:00:00 2001 From: sean1572 Date: Wed, 17 Dec 2025 17:08:57 -0800 Subject: [PATCH 12/18] Add inferance fixes for fewshot --- pyproject.toml | 4 - test.py | 79 +++++++++++++++++++ whoot_model_training/inferance.py | 38 ++++++--- .../data_extractor/raw_audio_extractor.py | 3 + .../preprocessors/waveform_preprocessors.py | 20 +++-- .../whoot_model_training/trainer.py | 31 ++++++-- 6 files changed, 148 insertions(+), 27 deletions(-) create mode 100644 test.py diff --git a/pyproject.toml b/pyproject.toml index 1e3be47..2f57094 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,10 +59,6 @@ notebooks = [ "seaborn>=0.13.2", ] -birdnet = [ - "birdnet>=0.1.7", -] - [packages.index] cu128 = "https://download.pytorch.org/whl/cu128" diff --git a/test.py b/test.py new file mode 100644 index 0000000..9423285 --- /dev/null +++ b/test.py @@ -0,0 +1,79 @@ +# # %% +# %load_ext autoreload +# %autoreload 1 + +# %% + +from whoot_model_training.whoot_model_training.preprocessors import WaveformInputPreprocessor +from whoot_model_training.whoot_model_training.models import HFInput, HFModel, HFModelConfig +from whoot_model_training.whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments +from whoot_model_training.whoot_model_training.data_extractor import xc_extractor +from whoot_model_training.whoot_model_training import CometMLLoggerSupplement + + +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# %% +ds = xc_extractor( + XC_dataset_json_path="xc_meta_aux.json", + parent_path="/mnt/acoustics/san_diego_xc_aux/xeno-canto" +) + + + +model = HFModel(HFModelConfig(num_classes=ds.get_number_species())) + +# %% +# %% + +input_wrapper = HFInput() + +train_preprocessor = WaveformInputPreprocessor( + input_wrapper, duration=3 +) + +preprocessor = WaveformInputPreprocessor( + input_wrapper, duration=3 +) + +ds["train"].set_transform(train_preprocessor) +ds["valid"].set_transform(preprocessor) +ds["test"].set_transform(preprocessor) + +print(ds.get_class_labels()) + +# run_name = "fewshot_test_birdmae" +# subproject_name = "fewshot_test" +# dataset_name = "san_diego_xc_aux_09_2025" + +# training_args = WhootTrainingArguments( +# run_name=run_name, +# subproject_name=subproject_name, +# dataset_name=dataset_name, +# ) + +# # COMMON OPTIONAL ARGS +# training_args.num_train_epochs = 100 +# 
training_args.eval_steps = 2000 +# training_args.per_device_train_batch_size = 16 +# training_args.per_device_eval_batch_size = 16 +# training_args.dataloader_num_workers = 16 +# training_args.run_name = run_name +# training_args.learning_rate = 0.01 +# training_args.save_strategy="steps", # Save at the end of each epoch +# training_args.save_total_limit=2 # Keep only the last 2 checkpoints + +# trainer = WhootTrainer( +# model=model, +# dataset=ds, +# training_args=training_args, +# logger=CometMLLoggerSupplement( +# augmentations=None, +# name=training_args.run_name +# ), +# ) + +# trainer.train() +# model.save_pretrained("model_checkpoints/fewshot_test_birdmae") + diff --git a/whoot_model_training/inferance.py b/whoot_model_training/inferance.py index 137cdfd..b1b51e1 100644 --- a/whoot_model_training/inferance.py +++ b/whoot_model_training/inferance.py @@ -14,6 +14,7 @@ """ import argparse import pickle +import datasets from train import parse_config, init_env @@ -22,6 +23,7 @@ from whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments from whoot_model_training.data_extractor import raw_audio_extractor from whoot_model_training.models import TimmModel, TimmInputs +from whoot_model_training.models import HFInput, HFModel, HFModelConfig from whoot_model_training import CometMLLoggerSupplement from whoot_model_training.preprocessors import ( MelModelInputPreprocessor @@ -53,7 +55,8 @@ def test( ds = raw_audio_extractor( audio_parent_folder=audio_dir, output_folder="data/manual_buowset", - chunk_duration=3 + chunk_duration=3, + class_list=["Abert's Towhee", 'Acorn Woodpecker', "Allen's Hummingbird", 'American Avocet', 'American Barn Owl', 'American Bittern', 'American Bullfrog', 'American Bushtit', 'American Cliff Swallow', 'American Coot', 'American Crow', 'American Dusky Flycatcher', 'American Goldfinch', 'American Grey Flycatcher', 'American Herring Gull', 'American Kestrel', 'American Redstart', 'American Robin', 'American Wigeon', 'American Yellow Warbler', "Anna's Hummingbird", 'Ash-throated Flycatcher', "Audubon's Warbler", 'Band-tailed Pigeon', 'Barn Swallow', 'Bay-breasted Warbler', "Bell's Sparrow", "Bell's Vireo", 'Belted Kingfisher', "Bewick's Wren", 'Black Phoebe', 'Black Skimmer', 'Black Turnstone', 'Black-chinned Hummingbird', 'Black-chinned Sparrow', 'Black-crowned Night Heron', 'Black-headed Grosbeak', 'Black-hooded Oriole', 'Black-necked Grebe', 'Black-necked Stilt', 'Black-tailed Gnatcatcher', 'Black-throated Grey Warbler', 'Black-throated Magpie-Jay', 'Black-throated Sparrow', 'Blue Grosbeak', 'Blue-crowned Parakeet', 'Blue-grey Gnatcatcher', "Bonaparte's Gull", "Brandt's Cormorant", 'Brant Goose', "Brewer's Blackbird", "Brewer's Sparrow", 'Brown Creeper', 'Brown-headed Cowbird', 'Buff-bellied Pipit', "Bullock's Oriole", 'Burrowing Owl', 'Burrowing Parrot', 'Cactus Wren', 'California Gnatcatcher', 'California Ground Squirrel', 'California Gull', 'California Quail', 'California Scrub Jay', 'California Thrasher', 'California Towhee', 'Canada Goose', 'Canada Warbler', 'Canyon Bat', 'Canyon Wren', 'Caspian Tern', "Cassin's Finch", "Cassin's Kingbird", "Cassin's Vireo", 'Cedar Waxwing', 'Chestnut-collared Longspur', 'Chipping Sparrow', 'Cinnamon Teal', 'Cinnamon-rumped Seedeater', "Clark's Grebe", "Clark's Nutcracker", 'Clay-colored Sparrow', 'Cockatiel', 'Common Gallinule', 'Common Ground Dove', 'Common Poorwill', 'Common Starling', 'Common Yellowthroat', "Cooper's Hawk", "Costa's Hummingbird", 'Coyote', "Craveri's Murrelet", 'Crissal Thrasher', 
'Dark-eyed Junco', 'Double-crested Cormorant', 'Downy Woodpecker', 'Dunlin', 'Dusky-capped Flycatcher', 'Eastern Subalpine Warbler', 'Elegant Tern', 'Eurasian Collared Dove', 'Evening Grosbeak', "Forster's Tern", 'Gadwall', "Gambel's Quail", 'Gila Woodpecker', 'Glaucous-blue Grosbeak', 'Glaucous-winged Gull', 'Golden-crowned Kinglet', 'Golden-crowned Sparrow', "Grace's Warbler", 'Grasshopper Sparrow', 'Great Blue Heron', 'Great Egret', 'Great Horned Owl', 'Great-tailed Grackle', 'Greater Pewee', 'Greater Roadrunner', 'Greater Yellowlegs', 'Green-tailed Towhee', 'Green-winged Teal', 'Grey Catbird', 'Grey Plover', 'Grey Vireo', 'Grey-hooded Warbler', 'Gull-billed Tern', 'Hairy Woodpecker', "Hammond's Flycatcher", "Heermann's Gull", 'Hermit Thrush', 'Hermit Warbler', 'Hooded Oriole', 'Hooded Warbler', 'Horned Lark', 'House Finch', 'House Sparrow', 'House Wren', 'Hudsonian Whimbrel', "Hutton's Vireo", 'Identity unknown', 'Inca Dove', 'Indian House Cricket', 'Killdeer', 'Lapland Longspur', 'Lark Sparrow', 'Laughing Gull', "Lawrence's Goldfinch", 'Lazuli Bunting', "LeConte's Thrasher", 'Least Bittern', 'Least Sandpiper', 'Least Tern', 'Lesser Goldfinch', 'Lesser Nighthawk', 'Lilac-crowned Amazon', "Lincoln's Sparrow", 'Little Blue Heron', 'Loggerhead Shrike', 'Long-billed Curlew', 'Long-billed Dowitcher', 'Long-chirp field cricket', 'Louisiana Waterthrush', "Lucy's Warbler", "MacGillivray's Warbler", 'Mallard', 'Marbled Godwit', 'Marsh Wren', "Merriam's Chipmunk", 'Mexican Whip-poor-will', 'Mountain Bluebird', 'Mountain Chickadee', 'Mountain Quail', 'Mourning Dove', 'Nashville Warbler', "Nelson's Sparrow", 'Northern Flicker', 'Northern Harrier', 'Northern Mockingbird', 'Northern Parula', 'Northern Pintail', 'Northern Raven', 'Northern Rough-winged Swallow', 'Northern Saw-whet Owl', 'Northern Shoveler', 'Northern Waterthrush', "Nuttall's Woodpecker", 'Oak Titmouse', 'Olive-sided Flycatcher', 'Orange-crowned Warbler', 'Pacific Golden Plover', 'Pacific Treefrog', 'Pacific Wren', 'Pacific-slope Flycatcher', 'Palm Warbler', 'Pelagic Cormorant', 'Peregrine Falcon', 'Phainopepla', 'Pied-billed Grebe', 'Pin-tailed Whydah', 'Pine Siskin', 'Pine Warbler', 'Pinyon Jay', 'Plumbeous Vireo', 'Prairie Warbler', 'Purple Finch', 'Purple Martin', 'Pygmy Nuthatch', 'Red Crossbill', 'Red-breasted Nuthatch', 'Red-crowned Amazon', 'Red-crowned Crane', 'Red-eyed Vireo', 'Red-faced Warbler', 'Red-masked Parakeet', 'Red-naped Sapsucker', 'Red-necked Grebe', 'Red-necked Phalarope', 'Red-shouldered Hawk', 'Red-tailed Hawk', 'Red-throated Pipit', 'Red-winged Blackbird', 'Redhead', "Ridgway's Rail", 'Ring-billed Gull', 'Rock Dove', 'Rock Wren', 'Rose-breasted Grosbeak', "Ross's Goose", 'Round-tailed Ground Squirrel', 'Royal Tern', 'Ruby-crowned Kinglet', 'Ruddy Duck', 'Rufous Hummingbird', 'Rufous-crowned Sparrow', 'Rusty Blackbird', 'Sage Thrasher', 'Sanderling', 'Sandhill Crane', 'Savannah Sparrow', "Say's Phoebe", 'Scaly-breasted Munia', "Scott's Oriole", 'Sharp-shinned Hawk', 'Short-billed Dowitcher', 'Slate-colored Fox Sparrow', 'Snow Goose', 'Snow Mountain Quail', 'Snowy Egret', 'Snowy Plover', 'Solitary Sandpiper', 'Song Sparrow', 'Sooty Fox Sparrow', 'Sora', 'Soundscape', 'Spotted Sandpiper', 'Spotted Towhee', "Steller's Jay", 'Stilt Sandpiper', 'Summer Tanager', 'Surf Scoter', 'Surfbird', "Swainson's Thrush", "Swinhoe's White-eye", 'Tennessee Warbler', 'Thick-billed Fox Sparrow', 'Thick-billed Kingbird', 'Thick-billed Longspur', "Townsend's Solitaire", "Townsend's Warbler", 'Tree Swallow', 'Tricolored Blackbird', 
'Tropical Kingbird', 'Two-barred Crossbill', 'Verdin', 'Vermilion Flycatcher', 'Violet-green Swallow', 'Virginia Rail', 'Vocal field cricket', 'Wandering Tattler', 'Warbling Vireo', 'Western Bluebird', 'Western Cattle Egret', 'Western Grebe', 'Western Gull', 'Western Kingbird', 'Western Meadowlark', 'Western Osprey', 'Western Sandpiper', 'Western Screech Owl', 'Western Subalpine Warbler', 'Western Tanager', 'Western Wood Pewee', 'White-breasted Nuthatch', 'White-crowned Sparrow', 'White-eyed Vireo', 'White-faced Ibis', 'White-tailed Kite', 'White-throated Sparrow', 'White-throated Swift', 'White-winged Dove', 'Wild Turkey', 'Willet', 'Willow Flycatcher', "Wilson's Snipe", "Wilson's Warbler", 'Wood Duck', 'Wrentit', 'Yellow-breasted Chat', 'Yellow-crowned Night Heron', 'Yellow-footed Gull'] ) # ds = buowset_extractor( @@ -63,18 +66,25 @@ def test( # ) # Create the model - model = TimmModel.from_pretrained(model_name) + + model = HFModel.from_pretrained(model_name).cuda() - preprocessor = MelModelInputPreprocessor( - TimmInputs, duration=3 + # %% + # %% + + input_wrapper = HFInput() + + train_preprocessor = WaveformInputPreprocessor( + input_wrapper, duration=3 ) - ds["train"].set_transform(preprocessor) + + ds["train"].set_transform(train_preprocessor) # ds["valid"].set_transform(preprocessor) # ds["test"].set_transform(preprocessor) model_name = "efficientnet_b1" - run_name = f"buowset1.1_{model_name}_ATTEMPT_TO_STUDY_NEW_DATA" + run_name = f"fewshot_test_birdmae_11_12_2025_14_checkpoint-137500" # trainer = WhootTrainer._load_from_checkpoint(model_name) # Run training @@ -105,12 +115,18 @@ def test( # print(ds["train"].shape, ds["test"].shape, ds["valid"].shape) # input() - out = trainer.predict(ds["train"], metric_key_prefix="train") - print(out) + out = trainer.predict(ds["train"]) + # Pipeline requires a labels col + # For inferance the "labels" are just an array of zeros + # Therefore during inferance, "labels" are meaningless + # Delete them to make it clearer to downstream users + del out['labels'] + with open(run_name + ".pkl", mode="wb") as f: pickle.dump(out, f) - # trainer.evaluate(ds["test"], metric_key_prefix="test") - # trainer.evaluate(ds["valid"], metric_key_prefix="valid") + # Below was tested with the pickle made from above + ds = datasets.Dataset.from_dict(out) + ds.save_to_disk(f"predictions/{run_name}") # saves as a directory if __name__ == "__main__": @@ -120,7 +136,7 @@ def test( "--model_name", required=False, help="path to weights or hugging face repo id", - default="/home/sean/whoot/checkpoint-4985") + default="/home/sean/whoot/model_checkpoints/fewshot_test_birdmae_11_12_2025_14:04:09/checkpoint-137500") args = parser.parse_args() _config = parse_config(args.config) diff --git a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py index b9c560b..2af9caa 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py @@ -213,6 +213,9 @@ def get_array_chunks_from_memory( except IOError as e: print(e, file_path, "failed stat read", "continuing") continue + except EOFError as e: + print(e, file_path, "failed stat read, reached end of file", "continuing") + continue for i in tqdm( range(0, int(floor(clip_length)), chunk_length_sec), leave=False, diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py 
b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py index 03f1203..75e644f 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py @@ -81,13 +81,19 @@ def __call__(self, batch): for item_idx in range(len(batch["audio"])): label = batch["labels"][item_idx] try: - #TEMP FIX TO PREVENT LONG AUDIO RECORDINGS BEING BAD!!! - if librosa.get_duration(path=batch["audio"][item_idx]["path"]) > 2 * 60: - continue - y, sr = librosa.load( - path=batch["audio"][item_idx]["path"], - sr=self.sr - ) + # TODO: This is a solid section of code for loading audio + # Consider turning this into a common helper function + if len(batch["audio"][item_idx]["array"]) > 10: + y = batch["audio"][item_idx]["array"] + sr = batch["audio"][item_idx]["sampling_rate"] + else: + if librosa.get_duration(path=batch["audio"][item_idx]["path"]) > 2 * 60: + continue + y, sr = librosa.load( + path=batch["audio"][item_idx]["path"], + sr=self.sr + ) + except Exception as e: print(e) print("File Likely is corrupted, moving on") diff --git a/whoot_model_training/whoot_model_training/trainer.py b/whoot_model_training/whoot_model_training/trainer.py index bfe5c39..7a1f3d2 100644 --- a/whoot_model_training/whoot_model_training/trainer.py +++ b/whoot_model_training/whoot_model_training/trainer.py @@ -14,6 +14,8 @@ import torch from tqdm import tqdm +from datasets import Audio + from pyha_analyzer import PyhaTrainingArguments from pyha_analyzer import PyhaTrainer @@ -22,6 +24,7 @@ from .models import Model + class WhootTrainingArguments(PyhaTrainingArguments): """Holds arguments use for training.""" @@ -129,14 +132,32 @@ def predict( metric_key_prefix: str = "test" Returns: test_dataset with a new col: "pred" """ + + # test_dataset = test_dataset.select(range(100)) test_dataloader = self.get_test_dataloader(test_dataset) preds = [] + count = 0 for batch in tqdm(test_dataloader): - preds.append(self.model( - self.model.input_format(**batch) - )["logits"].detach().cpu()) - - dataset = test_dataset.to_dict() + + try: + preds.append(self.model( + self.model.input_format(**batch) + )["logits"].detach().cpu()) + except Exception as e: + print(e) + break + count += 1 + # if count > 10: + # break + + + + dataset = test_dataset.with_format()#.cast_column("audio", Audio(decode=False)) + if count == len(test_dataloader): + dataset = dataset.to_dict() + else: #If we had a failure, save what data we proccessed + dataset = dataset.select(range(count * 16)).to_dict() + # dataset = dataset.to_dict() dataset["pred"] = torch.concat(preds).detach().numpy() return dataset From a722c0956a6b91833915bfd49dd51353680bdd6d Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 9 Jan 2026 11:54:24 -0800 Subject: [PATCH 13/18] Add fixes to model training BUG NOTICE: Continue broke inferance in waveform_preprocessors.py... need to rethink how addressing corrutpted audio files works in model training. Maybe replace with empty data so it doesn't break training? 
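One possible shape for that zero-fill fallback, sketched here for discussion only (not part of this patch; the helper name load_audio_or_silence and the fixed pad length are assumptions): load the clip as usual and substitute a silent array of the expected length when decoding fails, so batch shapes stay consistent and a single corrupted file cannot abort training or inference.

    import librosa
    import numpy as np

    def load_audio_or_silence(path, sr, duration):
        """Hypothetical helper: load audio, fall back to silence on failure."""
        try:
            # Resample to the target rate so downstream chunking stays uniform
            return librosa.load(path=path, sr=sr)
        except Exception as err:  # corrupted or unreadable file
            print(err, "-- file likely corrupted, substituting silence")
            # A zero-filled clip keeps the batch shape without raising
            return np.zeros(int(sr * duration)), sr

Whether the silent clip should also carry an all-zero label, or simply be dropped later by the duration filter, is left as a design choice for the real fix.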
--- whoot_model_training/inferance.py | 2 +- .../preprocessors/waveform_preprocessors.py | 4 +-- .../whoot_model_training/trainer.py | 36 ++++++++++++------- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/whoot_model_training/inferance.py b/whoot_model_training/inferance.py index b1b51e1..5691aec 100644 --- a/whoot_model_training/inferance.py +++ b/whoot_model_training/inferance.py @@ -115,7 +115,7 @@ def test( # print(ds["train"].shape, ds["test"].shape, ds["valid"].shape) # input() - out = trainer.predict(ds["train"]) + out = trainer.predict(ds["train"], save_path=f"predictions/{run_name}") # Pipeline requires a labels col # For inferance the "labels" are just an array of zeros # Therefore during inferance, "labels" are meaningless diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py index 75e644f..9370230 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py @@ -88,7 +88,7 @@ def __call__(self, batch): sr = batch["audio"][item_idx]["sampling_rate"] else: if librosa.get_duration(path=batch["audio"][item_idx]["path"]) > 2 * 60: - continue + break y, sr = librosa.load( path=batch["audio"][item_idx]["path"], sr=self.sr @@ -97,7 +97,7 @@ def __call__(self, batch): except Exception as e: print(e) print("File Likely is corrupted, moving on") - continue + break start = np.random.uniform(0, len(y)/sr - self.duration) diff --git a/whoot_model_training/whoot_model_training/trainer.py b/whoot_model_training/whoot_model_training/trainer.py index 7a1f3d2..4e51427 100644 --- a/whoot_model_training/whoot_model_training/trainer.py +++ b/whoot_model_training/whoot_model_training/trainer.py @@ -23,6 +23,8 @@ from .dataset import AudioDataset from .models import Model +import numpy as np + class WhootTrainingArguments(PyhaTrainingArguments): @@ -120,7 +122,8 @@ def predict( self, test_dataset: AudioDataset, ignore_keys=None, - metric_key_prefix: str = "test" + metric_key_prefix: str = "test", + save_path = "", ): """Run Inferance on a given dataset. @@ -137,27 +140,34 @@ def predict( test_dataloader = self.get_test_dataloader(test_dataset) preds = [] + data_selected = [] count = 0 for batch in tqdm(test_dataloader): try: - preds.append(self.model( + pred = self.model( self.model.input_format(**batch) - )["logits"].detach().cpu()) + )["logits"].detach().cpu().half() + preds.append(pred) #The current RAM use is 99/120... 
To be safe I'm going to reduce bytes + data_selected.extend(range(count, count + len(pred))) + count += len(pred) except Exception as e: - print(e) - break - count += 1 - # if count > 10: + print(e, "break in batch, don't use") + count += 16 + continue + + if count % 101 == 0: # break + import datasets + dataset = test_dataset.with_format() + out = dataset.select(data_selected).to_dict() + out["pred"] = torch.concat(preds).detach().numpy() + ds = datasets.Dataset.from_dict(out) + ds.save_to_disk(save_path) # saves as a directory - dataset = test_dataset.with_format()#.cast_column("audio", Audio(decode=False)) - if count == len(test_dataloader): - dataset = dataset.to_dict() - else: #If we had a failure, save what data we proccessed - dataset = dataset.select(range(count * 16)).to_dict() - # dataset = dataset.to_dict() + dataset = test_dataset.with_format() + dataset = dataset.select(data_selected).to_dict() dataset["pred"] = torch.concat(preds).detach().numpy() return dataset From d1136d2eb3180b5c4f9315da6db244d35037da6d Mon Sep 17 00:00:00 2001 From: sean1572 Date: Tue, 13 Jan 2026 16:28:00 -0800 Subject: [PATCH 14/18] Fix bug with corrupted files --- .../preprocessors/waveform_preprocessors.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py index 9370230..592b037 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py @@ -95,9 +95,11 @@ def __call__(self, batch): ) except Exception as e: + y = np.zeros(self.sr * 5) + sr = self.sr print(e) print("File Likely is corrupted, moving on") - break + continue start = np.random.uniform(0, len(y)/sr - self.duration) From cf750d1a6c69243d85316b07256c44853506f2eb Mon Sep 17 00:00:00 2001 From: sean1572 Date: Tue, 13 Jan 2026 16:30:47 -0800 Subject: [PATCH 15/18] Outdated inference script --- whoot_model_training/inferance.py | 144 ------------------------------ 1 file changed, 144 deletions(-) delete mode 100644 whoot_model_training/inferance.py diff --git a/whoot_model_training/inferance.py b/whoot_model_training/inferance.py deleted file mode 100644 index 5691aec..0000000 --- a/whoot_model_training/inferance.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Trains a Mutliclass Model with Pytorch and Huggingface. - -This script can be used to run experiments with different -models and datasets to create any model for bioacoustic classification - -It is intended this script to be heavily modified with each experiment -(say one wants to use a different dataset, one should copy this and change the -extractor!) 
- -Usage: - $ python train.py /path/to/config.yml - -config.yml should contain frequently changed hyperparameters -""" -import argparse -import pickle -import datasets - -from train import parse_config, init_env - - -from whoot_model_training.preprocessors.base_preprocessor import WaveformInputPreprocessor -from whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments -from whoot_model_training.data_extractor import raw_audio_extractor -from whoot_model_training.models import TimmModel, TimmInputs -from whoot_model_training.models import HFInput, HFModel, HFModelConfig -from whoot_model_training import CometMLLoggerSupplement -from whoot_model_training.preprocessors import ( - MelModelInputPreprocessor -) - - - - -def test( - config, - model_name="", - audio_dir="/mnt/restorage/Audiomoth/Raw sound files/2024/RGCB/" -): - """Highest level logic for inferance. - - Does the following: - - Formats the dataset into an AudioDataset - - Prepares preprocessing for each audio clip - - Builds the model - - Configures and runs the trainer - - Runs evaluation - - Args: - config (dict): the config used for training. Defined in yaml file - model_name (str): path to model checkpoint to use - audio_dir (str): path to unlabeled data - """ - # Extract a new dataset - ds = raw_audio_extractor( - audio_parent_folder=audio_dir, - output_folder="data/manual_buowset", - chunk_duration=3, - class_list=["Abert's Towhee", 'Acorn Woodpecker', "Allen's Hummingbird", 'American Avocet', 'American Barn Owl', 'American Bittern', 'American Bullfrog', 'American Bushtit', 'American Cliff Swallow', 'American Coot', 'American Crow', 'American Dusky Flycatcher', 'American Goldfinch', 'American Grey Flycatcher', 'American Herring Gull', 'American Kestrel', 'American Redstart', 'American Robin', 'American Wigeon', 'American Yellow Warbler', "Anna's Hummingbird", 'Ash-throated Flycatcher', "Audubon's Warbler", 'Band-tailed Pigeon', 'Barn Swallow', 'Bay-breasted Warbler', "Bell's Sparrow", "Bell's Vireo", 'Belted Kingfisher', "Bewick's Wren", 'Black Phoebe', 'Black Skimmer', 'Black Turnstone', 'Black-chinned Hummingbird', 'Black-chinned Sparrow', 'Black-crowned Night Heron', 'Black-headed Grosbeak', 'Black-hooded Oriole', 'Black-necked Grebe', 'Black-necked Stilt', 'Black-tailed Gnatcatcher', 'Black-throated Grey Warbler', 'Black-throated Magpie-Jay', 'Black-throated Sparrow', 'Blue Grosbeak', 'Blue-crowned Parakeet', 'Blue-grey Gnatcatcher', "Bonaparte's Gull", "Brandt's Cormorant", 'Brant Goose', "Brewer's Blackbird", "Brewer's Sparrow", 'Brown Creeper', 'Brown-headed Cowbird', 'Buff-bellied Pipit', "Bullock's Oriole", 'Burrowing Owl', 'Burrowing Parrot', 'Cactus Wren', 'California Gnatcatcher', 'California Ground Squirrel', 'California Gull', 'California Quail', 'California Scrub Jay', 'California Thrasher', 'California Towhee', 'Canada Goose', 'Canada Warbler', 'Canyon Bat', 'Canyon Wren', 'Caspian Tern', "Cassin's Finch", "Cassin's Kingbird", "Cassin's Vireo", 'Cedar Waxwing', 'Chestnut-collared Longspur', 'Chipping Sparrow', 'Cinnamon Teal', 'Cinnamon-rumped Seedeater', "Clark's Grebe", "Clark's Nutcracker", 'Clay-colored Sparrow', 'Cockatiel', 'Common Gallinule', 'Common Ground Dove', 'Common Poorwill', 'Common Starling', 'Common Yellowthroat', "Cooper's Hawk", "Costa's Hummingbird", 'Coyote', "Craveri's Murrelet", 'Crissal Thrasher', 'Dark-eyed Junco', 'Double-crested Cormorant', 'Downy Woodpecker', 'Dunlin', 'Dusky-capped Flycatcher', 'Eastern Subalpine Warbler', 'Elegant Tern', 'Eurasian Collared Dove', 
'Evening Grosbeak', "Forster's Tern", 'Gadwall', "Gambel's Quail", 'Gila Woodpecker', 'Glaucous-blue Grosbeak', 'Glaucous-winged Gull', 'Golden-crowned Kinglet', 'Golden-crowned Sparrow', "Grace's Warbler", 'Grasshopper Sparrow', 'Great Blue Heron', 'Great Egret', 'Great Horned Owl', 'Great-tailed Grackle', 'Greater Pewee', 'Greater Roadrunner', 'Greater Yellowlegs', 'Green-tailed Towhee', 'Green-winged Teal', 'Grey Catbird', 'Grey Plover', 'Grey Vireo', 'Grey-hooded Warbler', 'Gull-billed Tern', 'Hairy Woodpecker', "Hammond's Flycatcher", "Heermann's Gull", 'Hermit Thrush', 'Hermit Warbler', 'Hooded Oriole', 'Hooded Warbler', 'Horned Lark', 'House Finch', 'House Sparrow', 'House Wren', 'Hudsonian Whimbrel', "Hutton's Vireo", 'Identity unknown', 'Inca Dove', 'Indian House Cricket', 'Killdeer', 'Lapland Longspur', 'Lark Sparrow', 'Laughing Gull', "Lawrence's Goldfinch", 'Lazuli Bunting', "LeConte's Thrasher", 'Least Bittern', 'Least Sandpiper', 'Least Tern', 'Lesser Goldfinch', 'Lesser Nighthawk', 'Lilac-crowned Amazon', "Lincoln's Sparrow", 'Little Blue Heron', 'Loggerhead Shrike', 'Long-billed Curlew', 'Long-billed Dowitcher', 'Long-chirp field cricket', 'Louisiana Waterthrush', "Lucy's Warbler", "MacGillivray's Warbler", 'Mallard', 'Marbled Godwit', 'Marsh Wren', "Merriam's Chipmunk", 'Mexican Whip-poor-will', 'Mountain Bluebird', 'Mountain Chickadee', 'Mountain Quail', 'Mourning Dove', 'Nashville Warbler', "Nelson's Sparrow", 'Northern Flicker', 'Northern Harrier', 'Northern Mockingbird', 'Northern Parula', 'Northern Pintail', 'Northern Raven', 'Northern Rough-winged Swallow', 'Northern Saw-whet Owl', 'Northern Shoveler', 'Northern Waterthrush', "Nuttall's Woodpecker", 'Oak Titmouse', 'Olive-sided Flycatcher', 'Orange-crowned Warbler', 'Pacific Golden Plover', 'Pacific Treefrog', 'Pacific Wren', 'Pacific-slope Flycatcher', 'Palm Warbler', 'Pelagic Cormorant', 'Peregrine Falcon', 'Phainopepla', 'Pied-billed Grebe', 'Pin-tailed Whydah', 'Pine Siskin', 'Pine Warbler', 'Pinyon Jay', 'Plumbeous Vireo', 'Prairie Warbler', 'Purple Finch', 'Purple Martin', 'Pygmy Nuthatch', 'Red Crossbill', 'Red-breasted Nuthatch', 'Red-crowned Amazon', 'Red-crowned Crane', 'Red-eyed Vireo', 'Red-faced Warbler', 'Red-masked Parakeet', 'Red-naped Sapsucker', 'Red-necked Grebe', 'Red-necked Phalarope', 'Red-shouldered Hawk', 'Red-tailed Hawk', 'Red-throated Pipit', 'Red-winged Blackbird', 'Redhead', "Ridgway's Rail", 'Ring-billed Gull', 'Rock Dove', 'Rock Wren', 'Rose-breasted Grosbeak', "Ross's Goose", 'Round-tailed Ground Squirrel', 'Royal Tern', 'Ruby-crowned Kinglet', 'Ruddy Duck', 'Rufous Hummingbird', 'Rufous-crowned Sparrow', 'Rusty Blackbird', 'Sage Thrasher', 'Sanderling', 'Sandhill Crane', 'Savannah Sparrow', "Say's Phoebe", 'Scaly-breasted Munia', "Scott's Oriole", 'Sharp-shinned Hawk', 'Short-billed Dowitcher', 'Slate-colored Fox Sparrow', 'Snow Goose', 'Snow Mountain Quail', 'Snowy Egret', 'Snowy Plover', 'Solitary Sandpiper', 'Song Sparrow', 'Sooty Fox Sparrow', 'Sora', 'Soundscape', 'Spotted Sandpiper', 'Spotted Towhee', "Steller's Jay", 'Stilt Sandpiper', 'Summer Tanager', 'Surf Scoter', 'Surfbird', "Swainson's Thrush", "Swinhoe's White-eye", 'Tennessee Warbler', 'Thick-billed Fox Sparrow', 'Thick-billed Kingbird', 'Thick-billed Longspur', "Townsend's Solitaire", "Townsend's Warbler", 'Tree Swallow', 'Tricolored Blackbird', 'Tropical Kingbird', 'Two-barred Crossbill', 'Verdin', 'Vermilion Flycatcher', 'Violet-green Swallow', 'Virginia Rail', 'Vocal field cricket', 'Wandering Tattler', 'Warbling 
Vireo', 'Western Bluebird', 'Western Cattle Egret', 'Western Grebe', 'Western Gull', 'Western Kingbird', 'Western Meadowlark', 'Western Osprey', 'Western Sandpiper', 'Western Screech Owl', 'Western Subalpine Warbler', 'Western Tanager', 'Western Wood Pewee', 'White-breasted Nuthatch', 'White-crowned Sparrow', 'White-eyed Vireo', 'White-faced Ibis', 'White-tailed Kite', 'White-throated Sparrow', 'White-throated Swift', 'White-winged Dove', 'Wild Turkey', 'Willet', 'Willow Flycatcher', "Wilson's Snipe", "Wilson's Warbler", 'Wood Duck', 'Wrentit', 'Yellow-breasted Chat', 'Yellow-crowned Night Heron', 'Yellow-footed Gull'] - ) - - # ds = buowset_extractor( - # metadata_csv=config["metadata_csv"], - # parent_path=config["data_path"], - # output_path=config["hf_cache_path"], - # ) - - # Create the model - - model = HFModel.from_pretrained(model_name).cuda() - - # %% - # %% - - input_wrapper = HFInput() - - train_preprocessor = WaveformInputPreprocessor( - input_wrapper, duration=3 - ) - - - ds["train"].set_transform(train_preprocessor) - # ds["valid"].set_transform(preprocessor) - # ds["test"].set_transform(preprocessor) - - model_name = "efficientnet_b1" - run_name = f"fewshot_test_birdmae_11_12_2025_14_checkpoint-137500" - - # trainer = WhootTrainer._load_from_checkpoint(model_name) - # Run training - training_args = WhootTrainingArguments( - run_name=run_name, - subproject_name=config["SUBPROJECT_NAME"]+"_INFERANCE", - dataset_name=config["DATASET_NAME"], - ) - - # COMMON OPTIONAL ARGS - training_args.num_train_epochs = 5 - training_args.eval_steps = 100 - training_args.per_device_train_batch_size = 16 - training_args.per_device_eval_batch_size = 16 - training_args.dataloader_num_workers = 1 - training_args.run_name = run_name - - trainer = WhootTrainer( - model=model, - dataset=ds, - training_args=training_args, - logger=CometMLLoggerSupplement( - augmentations=None, - name=training_args.run_name - ), - ) - - # print(ds["train"].shape, ds["test"].shape, ds["valid"].shape) - # input() - - out = trainer.predict(ds["train"], save_path=f"predictions/{run_name}") - # Pipeline requires a labels col - # For inferance the "labels" are just an array of zeros - # Therefore during inferance, "labels" are meaningless - # Delete them to make it clearer to downstream users - del out['labels'] - - with open(run_name + ".pkl", mode="wb") as f: - pickle.dump(out, f) - # Below was tested with the pickle made from above - ds = datasets.Dataset.from_dict(out) - ds.save_to_disk(f"predictions/{run_name}") # saves as a directory - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Input config path") - parser.add_argument("config", type=str, help="Path to config.yml") - parser.add_argument( - "--model_name", - required=False, - help="path to weights or hugging face repo id", - default="/home/sean/whoot/model_checkpoints/fewshot_test_birdmae_11_12_2025_14:04:09/checkpoint-137500") - args = parser.parse_args() - _config = parse_config(args.config) - - init_env(_config) - test(_config, model_name=args.model_name) From 0111bec172db1fa3b2f5751f5691bd8c7ae71e00 Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 16 Jan 2026 11:00:34 -0800 Subject: [PATCH 16/18] Lint round 1 Linted most the model_trainer with flake8 and some of data downloader, realized that data downloader needs a major code clean up --- data_downloader/xc_aux_downloader.py | 87 +++++------ whoot_model_training/train.py | 4 - .../data_extractor/raw_audio_extractor.py | 5 +- .../data_extractor/xc_extractor.py | 5 +- 
.../models/few_shot_model.py | 135 +++++++++--------- .../spectrogram_preprocessors.py | 1 - .../preprocessors/waveform_preprocessors.py | 6 +- .../whoot_model_training/trainer.py | 20 +-- 8 files changed, 122 insertions(+), 141 deletions(-) diff --git a/data_downloader/xc_aux_downloader.py b/data_downloader/xc_aux_downloader.py index ca1fbac..af45a36 100644 --- a/data_downloader/xc_aux_downloader.py +++ b/data_downloader/xc_aux_downloader.py @@ -1,51 +1,22 @@ -# %% -from xc import XenoCantoDownloader -from dotenv import load_dotenv -import os - -# Load environment variables from the .env file -load_dotenv() - -xcd = XenoCantoDownloader(api_key=os.environ["XC_API_KEY"]) - - - -# %% -import json - -with open("data/xc_meta.json", mode="r") as f: - data = json.load(f) - -species = { recording["en"] for page in data for recording in page["recordings"] } - -# %% -len({recording["en"] for page in data for recording in page["recordings"] }) - -# %% -len(species) - -# %% -data = [] -import tqdm -for specie in tqdm.tqdm(list(species)): - data.append(xcd(query=f'en:"{specie}"')) +"""Downloads auxiliary Xeno-Canto data and audio files. -# %% -import itertools -data = list(itertools.chain.from_iterable(data)) - -# %% -with open("xc_meta_aux.json", mode="w") as f: - json.dump(data, f, indent=4) - -# %% +Relies on output from data_downloader/xc.py +Create a .env file with XC api-key +`XC_API_KEY=your_api_key_here` +Then call directly with `python xc_aux_downloader.py` +""" import requests - -# %% import shutil import os +import json from pathlib import Path from multiprocessing.pool import ThreadPool +from xc import XenoCantoDownloader +from dotenv import load_dotenv +import pandas as pd +import tqdm +import itertools + # https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests def download_file(url, local_filename, dry_run=False): @@ -84,17 +55,33 @@ def prep_download(args): pool.close() return results -results = download_files(xcd, data) -results +def main(): + # Load environment variables from the .env file + load_dotenv() -# %% -import pandas as pd -recordings = xcd.concat_recording_data(data) -df = pd.DataFrame(recordings) + xcd = XenoCantoDownloader(api_key=os.environ["XC_API_KEY"]) + + with open("data/xc_meta.json", mode="r") as f: + data = json.load(f) + + species = { recording["en"] for page in data for recording in page["recordings"] } + + # DEBUG + # len({recording["en"] for page in data for recording in page["recordings"] }) + # len(species) -df.shape + data = [] + for specie in tqdm.tqdm(list(species)): + data.append(xcd(query=f'en:"{specie}"')) -# %% + data = list(itertools.chain.from_iterable(data)) + with open("xc_meta_aux.json", mode="w") as f: + json.dump(data, f, indent=4) + results = download_files(xcd, data) + results + recordings = xcd.concat_recording_data(data) + df = pd.DataFrame(recordings) + print(df.shape) \ No newline at end of file diff --git a/whoot_model_training/train.py b/whoot_model_training/train.py index e35207f..0b5cde2 100644 --- a/whoot_model_training/train.py +++ b/whoot_model_training/train.py @@ -26,10 +26,6 @@ from whoot_model_training.preprocessors.spectrogram_preprocessors import ( SpectrogramParams, ) -from whoot_model_training.preprocessors.spectrogram_preprocessors import ( - SpectrogramParams -) - # Uncomment for use with data augmentation # from pyha_analyzer.preprocessors import MixItUp, ComposeAudioLabel # from audiomentations import ( diff --git 
a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py index ac42d89..83a7b64 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py @@ -227,7 +227,10 @@ def get_array_chunks_from_memory( ) continue except EOFError as e: - print(e, file_path, "failed stat read, reached end of file", "continuing") + print( + e, + file_path, + "failed stat read, reached end of file", "continuing") continue for i in tqdm( range(0, int(floor(clip_length)), chunk_length_sec), diff --git a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py index 62f665d..aee461e 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py @@ -163,9 +163,8 @@ def xc_extractor( ] ) - # Only accept less than 10 min long clips - # Longer clips seem to courrpt more easily... + # Longer clips seem to courrpt more easily... # Format is "#:##"" hence length 4 dataset = dataset.filter( lambda x: len(x["length"]) == 4 @@ -178,8 +177,6 @@ def xc_extractor( # num_proc=16 ) - - dataset = dataset.filter( lambda x: bad_file_path not in x["audio"], ) diff --git a/whoot_model_training/whoot_model_training/models/few_shot_model.py b/whoot_model_training/whoot_model_training/models/few_shot_model.py index 33a4c06..30d14cb 100644 --- a/whoot_model_training/whoot_model_training/models/few_shot_model.py +++ b/whoot_model_training/whoot_model_training/models/few_shot_model.py @@ -11,9 +11,10 @@ do processing on top of that embedding """ -from torch import nn, Tensor +# from torch import nn, Tensor # from perch_hoplite.zoo import model_configs -from .model import Model, ModelInput, ModelOutput, has_required_inputs +# from .model import Model, ModelInput, ModelOutput, has_required_inputs +from .model import ModelInput from transformers import PretrainedConfig @@ -46,29 +47,32 @@ def __init__( self["embedding"] = self.model.embed(waveform) + # Global variable fore PerchEmbeddings perch_model = None -class PerchEmbeddings(EmbeddingModel): - """Wrapper for getting embeddings from perch.""" - # Warning, was running into issues with memory here - # Early attempts recreated model - # Hoping using global var only loads it in once - if perch_model is None: - perch_model = model_configs.load_model_by_name('perch_8') +# TODO: Create Environment based loading of models +# class PerchEmbeddings(EmbeddingModel): +# """Wrapper for getting embeddings from perch.""" + +# # Warning, was running into issues with memory here +# # Early attempts recreated model +# # Hoping using global var only loads it in once +# if perch_model is None: +# perch_model = model_configs.load_model_by_name('perch_8') - model = perch_model +# model = perch_model - def embed(self, embeddings): - """Return embeddings.""" - return embeddings +# def embed(self, embeddings): +# """Return embeddings.""" +# return embeddings -class PerchEmbeddingInput(EmbeddingInput): - """Wrapper for an input into a larger model from perch.""" - model = PerchEmbeddings() - embedding_size = 1280 +# class PerchEmbeddingInput(EmbeddingInput): +# """Wrapper for an input into a larger model from perch.""" +# model = PerchEmbeddings() +# embedding_size = 1280 class 
FewShotModelConfig(PretrainedConfig): @@ -87,51 +91,52 @@ def __init__( super().__init__(**kwargs) -class PerchFewShotModel(Model, nn.Module): - """Perch model intergration with pytorch.""" - def __init__( - self, - config: FewShotModelConfig - ): - """Init for TimmModel. - - kwargs: - timm_model (str): name of model backbone from timms to use, - Default: "resnet34" - pretrained (bool): use a pretrained model from timms, Default: True - in_chans (int): number of channels of audio: Default: 1 - num_classes (int): number of classes in the dataset: Default 6 - loss (any): custom loss function Default: BCEWithLogitsLoss - """ - super().__init__() - - self.input_format = PerchEmbeddingInput - self.output_format = ModelOutput - - self.config = config - assert config.num_classes > 0 - - self.linear = nn.Linear( - self.input_format.embedding_size, - config.num_classes - ) - - self.loss = nn.BCEWithLogitsLoss() - - @has_required_inputs() - def forward(self, x: PerchEmbeddingInput): - """Run model over x!""" - # Use perch to create embeddings - embeddings = Tensor( - x.model.model.embed(x["waveform"].cpu()).embeddings - ).to(x["waveform"].device) - - logits = self.linear(embeddings).squeeze(1) - loss = self.loss(logits, x["labels"]) - - return ModelOutput( - logits=logits, - embeddings=embeddings, - loss=loss, - labels=x["labels"] - ) +# class PerchFewShotModel(Model, nn.Module): +# """Perch model intergration with pytorch.""" +# def __init__( +# self, +# config: FewShotModelConfig +# ): +# """Init for TimmModel. + +# kwargs: +# timm_model (str): name of model backbone from timms to use, +# Default: "resnet34" +# pretrained (bool): use a pretrained model from timms, +# Default: True +# in_chans (int): number of channels of audio: Default: 1 +# num_classes (int): number of classes in the dataset: Default 6 +# loss (any): custom loss function Default: BCEWithLogitsLoss +# """ +# super().__init__() + +# self.input_format = PerchEmbeddingInput +# self.output_format = ModelOutput + +# self.config = config +# assert config.num_classes > 0 + +# self.linear = nn.Linear( +# self.input_format.embedding_size, +# config.num_classes +# ) + +# self.loss = nn.BCEWithLogitsLoss() + +# @has_required_inputs() +# def forward(self, x: PerchEmbeddingInput): +# """Run model over x!""" +# # Use perch to create embeddings +# embeddings = Tensor( +# x.model.model.embed(x["waveform"].cpu()).embeddings +# ).to(x["waveform"].device) + +# logits = self.linear(embeddings).squeeze(1) +# loss = self.loss(logits, x["labels"]) + +# return ModelOutput( +# logits=logits, +# embeddings=embeddings, +# loss=loss, +# labels=x["labels"] +# ) diff --git a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py index e326385..dc0bad1 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py @@ -107,7 +107,6 @@ def __call__(self, batch): ) pcen_S = librosa.pcen(S * (2**31)) - mels = ( np.array( pillow_transforms( diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py index 592b037..a003a73 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py @@ -87,13 
+87,15 @@ def __call__(self, batch): y = batch["audio"][item_idx]["array"] sr = batch["audio"][item_idx]["sampling_rate"] else: - if librosa.get_duration(path=batch["audio"][item_idx]["path"]) > 2 * 60: + if librosa.get_duration( + path=batch["audio"][item_idx]["path"] + ) > 2 * 60: break y, sr = librosa.load( path=batch["audio"][item_idx]["path"], sr=self.sr ) - + except Exception as e: y = np.zeros(self.sr * 5) sr = self.sr diff --git a/whoot_model_training/whoot_model_training/trainer.py b/whoot_model_training/whoot_model_training/trainer.py index 4e51427..0db9265 100644 --- a/whoot_model_training/whoot_model_training/trainer.py +++ b/whoot_model_training/whoot_model_training/trainer.py @@ -14,7 +14,7 @@ import torch from tqdm import tqdm -from datasets import Audio +import datasets from pyha_analyzer import PyhaTrainingArguments from pyha_analyzer import PyhaTrainer @@ -23,9 +23,6 @@ from .dataset import AudioDataset from .models import Model -import numpy as np - - class WhootTrainingArguments(PyhaTrainingArguments): """Holds arguments use for training.""" @@ -123,7 +120,7 @@ def predict( test_dataset: AudioDataset, ignore_keys=None, metric_key_prefix: str = "test", - save_path = "", + save_path="", ): """Run Inferance on a given dataset. @@ -143,29 +140,24 @@ def predict( data_selected = [] count = 0 for batch in tqdm(test_dataloader): - try: pred = self.model( self.model.input_format(**batch) )["logits"].detach().cpu().half() - preds.append(pred) #The current RAM use is 99/120... To be safe I'm going to reduce bytes + preds.append(pred) data_selected.extend(range(count, count + len(pred))) count += len(pred) except Exception as e: print(e, "break in batch, don't use") count += 16 continue - + if count % 101 == 0: - # break - import datasets dataset = test_dataset.with_format() out = dataset.select(data_selected).to_dict() out["pred"] = torch.concat(preds).detach().numpy() - ds = datasets.Dataset.from_dict(out) - ds.save_to_disk(save_path) # saves as a directory - - + # saves as a directory + datasets.Dataset.from_dict(out).save_to_disk(save_path) dataset = test_dataset.with_format() dataset = dataset.select(data_selected).to_dict() From 18304e7c89170a02b600de7114619451a9974577 Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 16 Jan 2026 11:52:30 -0800 Subject: [PATCH 17/18] Linited Data Downloader --- data_downloader/downloader_demo.ipynb | 36 ++-- data_downloader/get_more_species_data.ipynb | 184 -------------------- data_downloader/xc.py | 95 +++++++--- data_downloader/xc_aux_downloader.py | 66 +++++-- 4 files changed, 131 insertions(+), 250 deletions(-) delete mode 100644 data_downloader/get_more_species_data.ipynb diff --git a/data_downloader/downloader_demo.ipynb b/data_downloader/downloader_demo.ipynb index f90fa04..487dd4e 100644 --- a/data_downloader/downloader_demo.ipynb +++ b/data_downloader/downloader_demo.ipynb @@ -7,9 +7,15 @@ "metadata": {}, "outputs": [], "source": [ + "import os\n", + "import json\n", + "import requests\n", "from xc import XenoCantoDownloader\n", "from dotenv import load_dotenv\n", - "import os\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "\n", "\n", "# Load environment variables from the .env file\n", "load_dotenv()" @@ -44,9 +50,7 @@ "metadata": {}, "outputs": [], "source": [ - "data = xcd(query=\"box:32.485,-117.582,33.482,-115.228\")\n", - "d\n", - "#box:32.485,-117.582,33.482,-115.228" + "data = xcd(query=\"box:32.485,-117.582,33.482,-115.228\")" ] }, { @@ -56,8 +60,7 @@ "metadata": {}, 
"outputs": [], "source": [ - "import json\n", - "import requests\n", + "\n", "with open(\"xc_meta.json\", mode=\"w\") as f:\n", " json.dump(data, f, indent=4)" ] @@ -160,21 +163,10 @@ "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", "recordings = xcd.concat_recording_data(data)\n", "df = pd.DataFrame(recordings)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "14a73c0f", - "metadata": {}, - "outputs": [], - "source": [ - "!uv add --optional notebooks seaborn" - ] - }, { "cell_type": "code", "execution_count": null, @@ -182,11 +174,6 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib\n", - "import seaborn as sns\n", - "# df[\"en\"].value_counts().hist()\n", - "\n", - "\n", "sns.histplot(df[\"en\"].value_counts())" ] }, @@ -197,12 +184,11 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "\n", "plt.ylabel(\"Number of Species\")\n", "plt.xlabel(\"Number of Indivuals Per Species\")\n", "plt.title(\"Do We Have a Few-shot Learning Problem for XC in Southern California?\")\n", - "df[\"en\"].value_counts().hist()\n" + "df[\"en\"].value_counts().hist()\n", + "plt.show()" ] }, { diff --git a/data_downloader/get_more_species_data.ipynb b/data_downloader/get_more_species_data.ipynb deleted file mode 100644 index b9db482..0000000 --- a/data_downloader/get_more_species_data.ipynb +++ /dev/null @@ -1,184 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "cac5b89f", - "metadata": {}, - "outputs": [], - "source": [ - "# %%\n", - "from xc import XenoCantoDownloader\n", - "from dotenv import load_dotenv\n", - "import os\n", - "\n", - "# Load environment variables from the .env file\n", - "load_dotenv()\n", - "\n", - "xcd = XenoCantoDownloader(api_key=os.environ[\"XC_API_KEY\"])\n", - "\n", - "\n", - "import librosa\n", - "# %%\n", - "import json\n", - "\n", - "with open(\"../data/san_diego_xc_aux/xc_meta_aux.json\", mode=\"r\") as f:\n", - " data = json.load(f)\n", - " # json.dump(data, f, indent=4)\n", - "\n", - "# %%\n", - "import requests\n", - "\n", - "# %%\n", - "import shutil\n", - "import os\n", - "from pathlib import Path\n", - "from multiprocessing.pool import ThreadPool\n", - "import tqdm\n", - "\n", - "# # https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests\n", - "# def download_file(url, local_filename, dry_run=False):\n", - "# if os.path.exists(local_filename):\n", - "# try:\n", - "# librosa.load(path=local_filename)\n", - "# return local_filename\n", - "# except Exception as e:\n", - "# pass\n", - " \n", - "# try:\n", - "# with requests.get(url, stream=True) as r:\n", - "# with open(local_filename, 'wb') as f:\n", - "# if not dry_run:\n", - "# shutil.copyfileobj(r.raw, f)\n", - "# else:\n", - "# print(local_filename)\n", - "\n", - "# return local_filename\n", - "# except Exception as e:\n", - "# print(e, flush=True)\n", - "# return None\n", - "\n", - "\n", - "def prep_download(args, dry_run=False):\n", - " url = args[0]\n", - " local_filename = args[1]\n", - " if os.path.exists(local_filename):\n", - " try:\n", - " librosa.load(path=local_filename)\n", - " return local_filename\n", - " except Exception as e:\n", - " print(local_filename, e, \"bad file, remake\")\n", - "\n", - " try:\n", - " with requests.get(url, stream=True) as r:\n", - " with open(local_filename, 'wb') as f:\n", - " if not dry_run:\n", - " shutil.copyfileobj(r.raw, f)\n", - " else:\n", - " print(local_filename)\n", - "\n", - " return local_filename\n", - " 
except Exception as e:\n", - " print(local_filename, e, flush=True)\n", - " return None\n", - "\n", - "def download_files(xcd, data, parent_folder=\"../data/san_diego_xc_aux/xeno-canto\", workers = 2):\n", - " \n", - "\n", - " os.makedirs(parent_folder, exist_ok=True)\n", - "\n", - " if \"recordings\" in data[0]:\n", - " data = xcd.concat_recording_data(data) \n", - " download_data = [\n", - " (recording[\"file\"], Path(parent_folder) / Path(recording[\"file-name\"].replace(\"/\", \"_\")))\n", - " for recording in data\n", - " ]\n", - "\n", - " with ThreadPool(processes=1024) as pool:\n", - " print(\"Main process: Submitting tasks...\")\n", - " \n", - " # Iterate over the results to wait for all tasks to complete.\n", - " # This loop will block until all tasks are finished.\n", - " for result in tqdm.tqdm(pool.imap_unordered(prep_download, download_data), total=len(download_data)):\n", - " if result is None:\n", - " print(\"ISSUE\")\n", - " \n", - " return results\n", - "\n", - "results = download_files(xcd, data)\n", - "results\n", - "\n", - "# %%\n", - "import pandas as pd\n", - "recordings = xcd.concat_recording_data(data)\n", - "df = pd.DataFrame(recordings)\n", - "\n", - "df.shape\n", - "\n", - "# %%\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ce6bcf04", - "metadata": {}, - "outputs": [], - "source": [ - "132510 / 303" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aeee1789", - "metadata": {}, - "outputs": [], - "source": [ - "df[\"en\"].value_counts()[df[\"en\"].value_counts() < 1000].hist(bins=50)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e0863f6a", - "metadata": {}, - "outputs": [], - "source": [ - "df[\"grp\"].value_counts()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "59bad81f", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "whoot", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/data_downloader/xc.py b/data_downloader/xc.py index a11d3c9..971f454 100644 --- a/data_downloader/xc.py +++ b/data_downloader/xc.py @@ -1,37 +1,69 @@ -import requests +"""Xeno-Canto Data Metadata Downloader and Search Module.""" import os -import json import urllib.parse +import json +import requests + class XenoCantoDownloader(): + """Handler for Xeno-Canto API. + + Note: Requires an API key from env var "XC_API_KEY". + Third version of the Xeno-Canto API is used here. + """ def __init__(self, api_key=None): + """Creates the Xeno-Canto Downloader. + + Args: + api_key (str): API key for Xeno-Canto API. + If None, looks for env var "XC_API_KEY" + """ self.endpoint_url = "https://xeno-canto.org/api/3/recordings" self.api_key = os.environ["XC_API_KEY"] if api_key is None else api_key - assert self.api_key is not None, "API KEY MISSING: Put API key in Enviroment Var!" + assert self.api_key is not None, \ + "API KEY MISSING: Put API key in Environment Var!" def __call__(self, - query = None, + query=None, loc=None, - box=None): - + ): + r"""Download XC data. + + Initally, this was intended to be used to build queries + So more args were planned (hence loc). 
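For context, the end-to-end use of this downloader is: hand-write a query with Xeno-Canto search tags, pull every page, then flatten the pages into one list of recording dicts. A short usage sketch, assuming XC_API_KEY is set in the environment; the example species tag is arbitrary:

from xc import XenoCantoDownloader

# Reads XC_API_KEY from the environment (or pass api_key=... explicitly)
xcd = XenoCantoDownloader()

# Hand-written query using a Xeno-Canto search tag, as in xc_aux_downloader.py
pages = xcd(query='en:"Burrowing Owl"')

# Flatten the per-page payloads into one list of recording dicts
recordings = xcd.concat_recording_data(pages)
print(len(recordings), "recordings found")
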
In practice, it was easier + to build queries by hand ¯\_(ツ)_/¯ + + You can pull the query you want from the url on the website if you + are manually searching for thigns there. Its the same syntax. + + Also is useful for debugging issues there + + Args: + query (str/None): Search query string see XC Search Tags + loc (str/None): Location string for search query + """ if query is None: query = self.build_query( loc=loc, - box=None, ) - + page_datas = [] page_data = self.get_page(query, page=1) page_datas.append(page_data) - + # Get rest of data! for i in range(2, page_data["numPages"] + 1): page_data = self.get_page(query, page=i) page_datas.append(page_data) return page_datas - + def concat_recording_data(self, page_datas): + """Concatinate recording data from multiple pages. + + Args: + page_datas (list): list of page data dicts + """ new_page_data = [] for page_data in page_datas: new_page_data = new_page_data + page_data["recordings"] @@ -40,31 +72,46 @@ def concat_recording_data(self, page_datas): def build_query( self, loc="San Diego, California, United States of America", - box=None, + # box=None, ): + """Builds a query string for Xeno-Canto API. + + See https://xeno-canto.org/help/search + Args: + loc (str): Location string for search query + """ search_tags = "" if loc is not None: search_tags += f"loc:\"{loc}\"+" - return search_tags[:-1] #remove last + + # Remove trailing + + return search_tags[:-1] def get_page(self, query, page=1): - res = requests.get(self.endpoint_url + "?"+ urllib.parse.urlencode({ - "query": query, - "key": self.api_key, - "page": page - })) + """Get a page of results from Xeno-Canto API. + + Args: + query (str): Search query string see XC Search Tags + page (int): Page number to retrieve + """ + res = requests.get( + self.endpoint_url + "?" + urllib.parse.urlencode({ + "query": query, + "key": self.api_key, + "page": page + }), + timeout=100 + ) if res.status_code == 200: return json.loads(res.text) - else: - {} - + + return {} # def download_files(self, data): # if type(data) == dict: - # data = self.concat_recording_data(self, data) - + # data = self.concat_recording_data(self, data) # for recording in data: # requests - + + if __name__ == "__main__": # parser = argparse.ArgumentParser( # description='Input Directory Path' @@ -73,4 +120,4 @@ def get_page(self, query, page=1): # help='Path to metadata csv') # args = parser.parse_args() xcd = XenoCantoDownloader() - print(xcd()) \ No newline at end of file + print(xcd()) diff --git a/data_downloader/xc_aux_downloader.py b/data_downloader/xc_aux_downloader.py index af45a36..cba00cc 100644 --- a/data_downloader/xc_aux_downloader.py +++ b/data_downloader/xc_aux_downloader.py @@ -5,38 +5,66 @@ `XC_API_KEY=your_api_key_here` Then call directly with `python xc_aux_downloader.py` """ -import requests + import shutil import os import json +import itertools from pathlib import Path from multiprocessing.pool import ThreadPool -from xc import XenoCantoDownloader from dotenv import load_dotenv import pandas as pd import tqdm -import itertools +import requests +from xc import XenoCantoDownloader # https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests def download_file(url, local_filename, dry_run=False): + """Download a file from a url to a local file. 
+ + Args: + url (str): url to download file from + local_filename (str): path to local file to save to + dry_run (bool): if True, do not actually download file + Returns: + local_filename (str): path to local file or None if failed + """ if os.path.exists(local_filename): return local_filename try: - with requests.get(url, stream=True) as r: + with requests.get(url, stream=True, timeout=1000) as r: with open(local_filename, 'wb') as f: if not dry_run: shutil.copyfileobj(r.raw, f) else: - print(local_filename) + print("Pretend download of", local_filename) return local_filename except IOError as e: print(e, flush=True) return None -def download_files(xcd, data, parent_folder="data/xeno-canto_aux", workers = 4): + +def download_files( + xcd: XenoCantoDownloader, + data: list, + parent_folder: str = "data/xeno-canto_aux", + workers: int = 4 +): + """Download all the files collected by the Xeno-Canto downloader. + + Args: + xcd (XenoCantoDownloader): the Xeno-Canto downloader object + Allows for preprocessing of recording metadata + data (list): list of recording data dicts + parent_folder (str): path to folder to store audio files + workers (int): number of parallel download workers + Tune down if hitting rate limits + Returns: + results (list): list of downloaded file paths + """ def prep_download(args): url = args[0] file_path = args[1] @@ -45,30 +73,30 @@ def prep_download(args): os.makedirs(parent_folder, exist_ok=True) if "recordings" in data[0]: - data = xcd.concat_recording_data(data) + data = xcd.concat_recording_data(data) download_data = [ (recording["file"], Path(parent_folder) / Path(recording["file-name"])) for recording in data ] pool = ThreadPool(workers) - results = pool.imap_unordered(prep_download, download_data) + results = pool.imap_unordered(prep_download, download_data) pool.close() return results + def main(): + """Script to download auxiliary Xeno-Canto data and audio files.""" # Load environment variables from the .env file load_dotenv() xcd = XenoCantoDownloader(api_key=os.environ["XC_API_KEY"]) - with open("data/xc_meta.json", mode="r") as f: + with open("data/xc_meta.json", mode="r", encoding="utf-8") as f: data = json.load(f) - species = { recording["en"] for page in data for recording in page["recordings"] } - - # DEBUG - # len({recording["en"] for page in data for recording in page["recordings"] }) - # len(species) + species = { + recording["en"] for page in data for recording in page["recordings"] + } data = [] for specie in tqdm.tqdm(list(species)): @@ -76,12 +104,16 @@ def main(): data = list(itertools.chain.from_iterable(data)) - with open("xc_meta_aux.json", mode="w") as f: + with open("xc_meta_aux.json", mode="w", encoding="utf-8") as f: json.dump(data, f, indent=4) results = download_files(xcd, data) - results + print("Done downloading files, num downloaded:", len(results)) recordings = xcd.concat_recording_data(data) df = pd.DataFrame(recordings) - print(df.shape) \ No newline at end of file + print("Metadata has shape:", df.shape) + + +if __name__ == "__main__": + main() From 815e11516d691ae350d9628462038a175a0e0867 Mon Sep 17 00:00:00 2001 From: sean1572 Date: Fri, 16 Jan 2026 16:01:40 -0800 Subject: [PATCH 18/18] Linted but needs testing --- test.py | 2 +- whoot_model_training/train.py | 4 +- .../data_extractor/esc50_extractor.py | 30 +----- .../data_extractor/raw_audio_extractor.py | 11 +- .../data_extractor/utils.py | 35 ++++++ .../data_extractor/xc_extractor.py | 50 +++------ .../whoot_model_training/models/__init__.py | 2 +- 
.../models/few_shot_model.py | 14 ++- .../whoot_model_training/models/hf_models.py | 14 +-- .../preprocessors/base_preprocessor.py | 4 +- .../preprocessors/default_preprocessor.py | 101 ++++++++++++++++++ .../spectrogram_preprocessors.py | 61 +++-------- .../preprocessors/waveform_preprocessors.py | 66 ++---------- .../whoot_model_training/trainer.py | 23 ++-- 14 files changed, 214 insertions(+), 203 deletions(-) create mode 100644 whoot_model_training/whoot_model_training/data_extractor/utils.py create mode 100644 whoot_model_training/whoot_model_training/preprocessors/default_preprocessor.py diff --git a/test.py b/test.py index 9423285..21c7a5f 100644 --- a/test.py +++ b/test.py @@ -4,7 +4,7 @@ # %% -from whoot_model_training.whoot_model_training.preprocessors import WaveformInputPreprocessor +from import WaveformInputPreprocessor from whoot_model_training.whoot_model_training.models import HFInput, HFModel, HFModelConfig from whoot_model_training.whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments from whoot_model_training.whoot_model_training.data_extractor import xc_extractor diff --git a/whoot_model_training/train.py b/whoot_model_training/train.py index 0b5cde2..5affeb9 100644 --- a/whoot_model_training/train.py +++ b/whoot_model_training/train.py @@ -71,7 +71,7 @@ def train(config): csv_path = "/home/sean/whoot/data/san_diego_xc_aux/xc_meta_aux.json" ds = xc_extractor( - XC_dataset_json_path=csv_path, + xc_dataset_json_path=csv_path, parent_path="/home/sean/whoot/data/san_diego_xc_aux/xeno-canto" ) @@ -166,7 +166,7 @@ def train(config): ) trainer.train() - model.save_pretrained("model_checkpoints/xc_aux") + model.save_pretrained("model_checkpoints/xc_aux_testing") def init_env(config: dict): diff --git a/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py index 6fab5bd..bf557da 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py @@ -18,28 +18,14 @@ import os from dataclasses import dataclass -import numpy as np from datasets import ( load_dataset, Audio, DatasetDict, - ClassLabel, - Sequence, ) -from ..dataset import AudioDataset - - -def one_hot_encode(row: dict, classes: list): - """One hot Encodes a list of labels. 
- Args: - row (dict): row of data in a dataset containing a labels column - classes: a list of classes - """ - one_hot = np.zeroes(len(classes)) - one_hot[row["labels"]] = 1 - row["labels"] = np.array(one_hot, dtype=float) - return row +from .utils import convert_labeled_dataset_onehot +from ..dataset import AudioDataset @dataclass @@ -84,17 +70,7 @@ def esc50_extractor( dataset = load_dataset("csv", data_files=metadata_csv)["train"] dataset = dataset.rename_column("category", "labels") - dataset = dataset.class_encode_column("labels") - - class_list = dataset.features["labels"].names - - multilabel_class_label = Sequence(ClassLabel(names=class_list)) - - dataset = dataset.map( - lambda row: one_hot_encode(row, class_list) - ).cast_column( - "labels", multilabel_class_label - ) + dataset = convert_labeled_dataset_onehot(dataset) dataset = dataset.add_column( "audio", [ diff --git a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py index 83a7b64..39a7350 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/raw_audio_extractor.py @@ -20,11 +20,10 @@ Dataset, table ) +from datasets.features.features import _FEATURE_TYPES, FeatureType import librosa from tqdm import tqdm import pyarrow as pa -from datasets.features.features import _FEATURE_TYPES, FeatureType - from ..dataset import AudioDataset @@ -218,14 +217,6 @@ def get_array_chunks_from_memory( " | ignoring and continuing" ) continue - except EOFError as e: - print( - e, - file_path, - "hit EOF too early, likely corrupted", - "| ignoring and continuing" - ) - continue except EOFError as e: print( e, diff --git a/whoot_model_training/whoot_model_training/data_extractor/utils.py b/whoot_model_training/whoot_model_training/data_extractor/utils.py new file mode 100644 index 0000000..3e67f60 --- /dev/null +++ b/whoot_model_training/whoot_model_training/data_extractor/utils.py @@ -0,0 +1,35 @@ +"""Utility functions for data extraction and preprocessing.""" +from datasets import ( + Dataset, + ClassLabel, + Sequence, +) + +import numpy as np + + +def one_hot_encode(row: dict, classes: list): + """One hot Encodes a list of labels. 
+ + Args: + row (dict): row of data in a dataset containing a labels column + classes: a list of classes + """ + one_hot = np.zeros(len(classes)) + one_hot[row["labels"]] = 1 + row["labels"] = np.array(one_hot, dtype=float) + return row + + +def convert_labeled_dataset_onehot(dataset: Dataset): + """Dataset with label column to one hot encoded version.""" + dataset = dataset.class_encode_column("labels") + class_list = dataset.features["labels"].names + multilabel_class_label = Sequence(ClassLabel(names=class_list)) + dataset = dataset.map( + lambda row: one_hot_encode(row, class_list) + ).cast_column( + "labels", + multilabel_class_label + ) + return dataset diff --git a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py index aee461e..737832a 100644 --- a/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py +++ b/whoot_model_training/whoot_model_training/data_extractor/xc_extractor.py @@ -5,23 +5,23 @@ import os import shutil +import json from pathlib import Path from dataclasses import dataclass from collections import Counter from pydub import AudioSegment -import json import librosa - -import numpy as np from datasets import ( Dataset, Audio, DatasetDict, ClassLabel, - Sequence, load_from_disk, ) from ..dataset import AudioDataset +from .utils import ( + convert_labeled_dataset_onehot, +) def filter_by_count(ds, col="en", threshold=10): @@ -46,22 +46,12 @@ def filter_xc_data(row: dict): # Prevents some files from taking forever librosa.load(path=file_path, duration=3) return True - except Exception as e: + except FileNotFoundError as e: + print(e, file_path) + return False + except IOError as e: print(e, file_path) return False - - -def one_hot_encode(row: dict, classes: list): - """One hot Encodes a list of labels. 
- - Args: - row (dict): row of data in a dataset containing a labels column - classes: a list of classes - """ - one_hot = np.zeros(len(classes)) - one_hot[row["labels"]] = 1 - row["labels"] = np.array(one_hot, dtype=float) - return row def convert_audio_to_flac(row, error_path="bad_files", col="audio"): @@ -83,14 +73,11 @@ def convert_audio_to_flac(row, error_path="bad_files", col="audio"): try: wav_audio = AudioSegment.from_file(file_path) wav_audio.export(flac_path, format="flac") - except Exception as e: + except IOError as e: if os.path.exists(file_path): os.makedirs(error_path, exist_ok=True) shutil.move(file_path, error_path) - # If quit halfway through processing, - # make sure we get rid of the bad file - # if os.path.exists(flac_path): - # os.remove(flac_path) + print( "ERROR", "move to", @@ -119,7 +106,7 @@ class XCParams(): def xc_extractor( - XC_dataset_json_path, + xc_dataset_json_path, parent_path, cache_path="data/san_diego_xc_aux/cache", params: XCParams = XCParams(), @@ -134,7 +121,7 @@ def xc_extractor( if os.path.exists(cache_path): return load_from_disk(cache_path) - with open(XC_dataset_json_path, mode="r") as f: + with open(xc_dataset_json_path, mode="r", encoding="utf-8") as f: xc_recordings_paged = json.load(f) xc_recordings = [] @@ -143,16 +130,13 @@ def xc_extractor( dataset = Dataset.from_list(xc_recordings) - dataset = dataset.add_column("labels", dataset["en"]) - dataset = dataset.class_encode_column("labels") - class_list = dataset.features["labels"].names - multilabel_class_label = Sequence(ClassLabel(names=class_list)) - dataset = dataset.map( - lambda row: one_hot_encode(row, class_list) - ).cast_column( + dataset = dataset.add_column( "labels", - multilabel_class_label + dataset["en"], + new_fingerprint="labels" ) + dataset = dataset.class_encode_column("labels") + dataset = convert_labeled_dataset_onehot(dataset) dataset = dataset.add_column( "audio", [ diff --git a/whoot_model_training/whoot_model_training/models/__init__.py b/whoot_model_training/whoot_model_training/models/__init__.py index 55aca50..76456fd 100644 --- a/whoot_model_training/whoot_model_training/models/__init__.py +++ b/whoot_model_training/whoot_model_training/models/__init__.py @@ -25,5 +25,5 @@ "ModelOutput", # "PerchEmbeddingInput", # "PerchFewShotModel", - "FewShotModelConfig" + # "FewShotModelConfig" ] diff --git a/whoot_model_training/whoot_model_training/models/few_shot_model.py b/whoot_model_training/whoot_model_training/models/few_shot_model.py index 30d14cb..1a0be28 100644 --- a/whoot_model_training/whoot_model_training/models/few_shot_model.py +++ b/whoot_model_training/whoot_model_training/models/few_shot_model.py @@ -14,8 +14,9 @@ # from torch import nn, Tensor # from perch_hoplite.zoo import model_configs # from .model import Model, ModelInput, ModelOutput, has_required_inputs -from .model import ModelInput + from transformers import PretrainedConfig +from .model import ModelInput class EmbeddingModel(): @@ -24,6 +25,10 @@ def embed(self): """Get embedding.""" raise NotImplementedError() + def get_k_neighbors(self): + """Get k nearest neighbors.""" + raise NotImplementedError() + class EmbeddingInput(ModelInput): """Wrapper for ModelInputs that are embeddings.""" @@ -45,14 +50,15 @@ def __init__( """ super().__init__(labels, waveform, spectrogram) + # I keep getting this linting error + # But there is not too many function args here + # pylint: disable=too-many-function-args self["embedding"] = self.model.embed(waveform) # Global variable fore PerchEmbeddings 
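The EmbeddingModel/EmbeddingInput pair above, like the commented-out PerchFewShotModel from patch 16, boils down to a linear probe trained on frozen embeddings with BCEWithLogitsLoss. A minimal PyTorch sketch of that idea, independent of Perch; the 1280-dim embedding size matches the commented-out PerchEmbeddingInput, while the class count and the random stand-in tensors are placeholders:

import torch
from torch import nn


class LinearProbe(nn.Module):
    """Single linear layer over precomputed (frozen) embeddings."""

    def __init__(self, embedding_size=1280, num_classes=10):
        super().__init__()
        self.linear = nn.Linear(embedding_size, num_classes)
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, embeddings, labels):
        logits = self.linear(embeddings)
        return {"logits": logits, "loss": self.loss(logits, labels)}


# Stand-in embeddings and multi-hot labels, just to show the shapes
probe = LinearProbe()
embeddings = torch.randn(4, 1280)
labels = torch.zeros(4, 10)
labels[:, 0] = 1.0
out = probe(embeddings, labels)
print(out["logits"].shape, float(out["loss"]))
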
-perch_model = None - +PERCH_MODEL = None -# TODO: Create Environment based loading of models # class PerchEmbeddings(EmbeddingModel): # """Wrapper for getting embeddings from perch.""" diff --git a/whoot_model_training/whoot_model_training/models/hf_models.py b/whoot_model_training/whoot_model_training/models/hf_models.py index 81f7e0e..bb428c8 100644 --- a/whoot_model_training/whoot_model_training/models/hf_models.py +++ b/whoot_model_training/whoot_model_training/models/hf_models.py @@ -1,19 +1,20 @@ """Wrapper around the hugging face model api!""" -from transformers import AutoFeatureExtractor, AutoModel +from contextlib import nullcontext + +from transformers import AutoFeatureExtractor, AutoModel, PretrainedConfig from torch import nn import torch -from contextlib import nullcontext -from transformers import PretrainedConfig from .model import Model, ModelInput, ModelOutput, has_required_inputs -class HFInput(): +class HFInput(ModelInput): """Input for Hugging Face Models. Specifies TimmModels needs labels and spectrograms that are Tensors """ + def __init__(self, labels=None, spectrogram=None, @@ -30,7 +31,7 @@ def __init__(self, self.feature_extractor = AutoFeatureExtractor.from_pretrained( extractor_path, trust_remote_code=True) - # TODO MAKE HFINPUT WORK WITH ITSELF + super().__init__(labels, waveform, spectrogram) def __call__(self, labels, spectrogram=None, waveform=None): """Create some fake ModelInputs for HFModels. @@ -134,11 +135,10 @@ def forward(self, x: HFInput) -> ModelOutput: x.spectrogram.to(self.device) ).last_hidden_state logits = self.linear(embed) - loss = self.loss(logits, x.labels) return ModelOutput( logits=logits, embeddings=embed, - loss=loss, + loss=self.loss(logits, x.labels), labels=x.labels ) diff --git a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py index 7ba1614..6e049ca 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py +++ b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py @@ -16,6 +16,8 @@ from pyha_analyzer.preprocessors import PreProcessorBase +from .default_preprocessor import DefaultPreprocessor + from .spectrogram_preprocessors import ( BuowMelSpectrogramPreprocessors, SpectrogramParams, @@ -36,7 +38,7 @@ class SpectrogramModelInPreprocessors(PreProcessorBase): def __init__( self, - spec_preprocessor: PreProcessorBase, + spec_preprocessor: DefaultPreprocessor, model_input: ModelInput, ): """Wrapper to get the raw spectrogram output of spec_preprocessor. diff --git a/whoot_model_training/whoot_model_training/preprocessors/default_preprocessor.py b/whoot_model_training/whoot_model_training/preprocessors/default_preprocessor.py new file mode 100644 index 0000000..32f57bd --- /dev/null +++ b/whoot_model_training/whoot_model_training/preprocessors/default_preprocessor.py @@ -0,0 +1,101 @@ +"""Defines a default preprocessor class. + +Now this allows for defining a set of common audio loading utilities. +""" +from dataclasses import dataclass +import librosa +import numpy as np +from pyha_analyzer.preprocessors import PreProcessorBase + + +@dataclass +class Augmentations(): + """Dataclass for the augmentations of the model. 
+ + audio (list[dict]): per item key name of augmentation, + value is the augmentation + spectrogram (list[dict]): same idea but augmentations + applied onto spectrograms + """ + audio = None + spectrogram = None + + +class DefaultPreprocessor(PreProcessorBase): + """Default Preprocessor class.""" + def __init__(self, name, duration, sr, *args, **kwargs): + """Initializes the DefaultPreprocessor. + + Args: + name (str): name of preprocessor for logging + duration (float): max length in seconds of audio chunk + sr (int/None): sample rate to standardize audio to + """ + super().__init__(name, *args, **kwargs) + self.duration = duration + self.sr = sr + + def load_audio(self, batch, item_idx): + """Load audio from either array or path. + + Args: + batch (dict): AudioDataset batch + item_idx (int): Processing an item in batch + Returns: + y (np.ndarray): audio array loaded + sr (int): sample rate of audio + """ + try: + if len(batch["audio"][item_idx]["array"]) > 10: + y = batch["audio"][item_idx]["array"] + sr = batch["audio"][item_idx]["sampling_rate"] + else: + if librosa.get_duration( + path=batch["audio"][item_idx]["path"] + ) > 2 * 60: + raise IOError("File too long to process") + + y, sr = librosa.load( + path=batch["audio"][item_idx]["path"], + sr=self.sr + ) + + except IOError as e: + y = np.zeros(self.sr * 5) + sr = self.sr + print("File Likely is corrupted, moving on", e) + raise IOError from e + + return y, sr + + def augment_audio( + self, + y: np.ndarray, + sr: int, + start: float, + label: str, + augments: Augmentations + ): + """Placeholder for audio augmentations. + + Args: + y: audio array + sr: sample rate + label: label associated with audio + start: starting point in seconds to crop audio + augments: augmentations to apply + """ + # Handle out of bound issues + end_sr = int(start * sr) + int(sr * self.duration) + if y.shape[-1] <= end_sr: + y = np.pad(y, end_sr - y.shape[-1]) + + # Audio Based Augmentations + if augments.audio is not None: + y, label = augments.audio(y, sr, label) + + new_y = y[int(start * sr):end_sr] + if new_y.shape[-1] < int(sr * self.duration): + raise IOError("Audio too short after augmentation") + + return new_y, label diff --git a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py index dc0bad1..db62310 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py @@ -2,14 +2,13 @@ Pulled from pyha_analyzer/preprocessors/spectogram_preprocessors.py """ - from dataclasses import dataclass import librosa import numpy as np from torchvision import transforms -from pyha_analyzer.preprocessors import PreProcessorBase +from .default_preprocessor import DefaultPreprocessor, Augmentations @dataclass @@ -28,21 +27,7 @@ class SpectrogramParams: n_mels: int = 256 -@dataclass -class Augmentations: - """Dataclass for the augmentations of the model. - - audio (list[dict]): per item key name of augmentation, - value is the augmentation - spectrogram (list[dict]): same idea but augmentations - applied onto spectrograms - """ - - audio = None - spectrogram = None - - -class BuowMelSpectrogramPreprocessors(PreProcessorBase): +class BuowMelSpectrogramPreprocessors(DefaultPreprocessor): """Preprocessor for processing audio into spectrograms. 
Particularly for the buow dataset @@ -72,45 +57,38 @@ def __init__( self.n_mels = spectrogram_params.n_mels self.spectrogram_params = spectrogram_params - super().__init__(name="MelSpectrogramPreprocessor") + super().__init__( + name="MelSpectrogramPreprocessor", duration=duration, sr=self.sr + ) def __call__(self, batch): """Process a batch of data from an AudioDataset.""" + # pylint: disable=duplicate-code new_audio = [] new_labels = [] for item_idx in range(len(batch["audio"])): label = batch["labels"][item_idx] - y, sr = ( - batch["audio"][item_idx]["array"], - batch["audio"][item_idx]["sampling_rate"], - ) + y, sr = self.load_audio(batch, item_idx) start = 0 - # Handle out of bound issues - end_sr = int(start * sr) + int(sr * self.duration) - if y.shape[-1] <= end_sr: - y = np.pad(y, end_sr - y.shape[-1]) - - # Audio Based Augmentations - if self.augments.audio is not None: - y, label = self.augments.audio(y, sr, label) + y, label = self.augment_audio(y, sr, start, label, self.augments) pillow_transforms = transforms.ToPILImage() - S = librosa.feature.melspectrogram( - y=y[int(start * sr):end_sr], + spec = librosa.feature.melspectrogram( + y=y, sr=sr, n_fft=self.n_fft, hop_length=self.hop_length, power=self.power, n_mels=self.n_mels, ) - pcen_S = librosa.pcen(S * (2**31)) + pcen_s = librosa.pcen(spec * (2**31)) mels = ( np.array( pillow_transforms( - pcen_S + pcen_s ), np.float32, )[np.newaxis, ::] @@ -158,8 +136,7 @@ class PCENMelSpectrogramPreprocessors(BuowMelSpectrogramPreprocessors): def __call__(self, batch): """Process a batch of data from an AudioDataset.""" - new_audio = [] - new_labels = [] + new_audio, new_labels = [], [] for item_idx in range(len(batch["audio"])): label = batch["labels"][item_idx] y, sr = ( @@ -167,20 +144,12 @@ def __call__(self, batch): batch["audio"][item_idx]["sampling_rate"], ) start = 0 - - # Handle out of bound issues - end_sr = int(start * sr) + int(sr * self.duration) - if y.shape[-1] <= end_sr: - y = np.pad(y, end_sr - y.shape[-1]) - - # Audio Based Augmentations - if self.augments.audio is not None: - y, label = self.augments.audio(y, sr, label) + y, label = self.augment_audio(y, sr, start, label, self.augments) pillow_transforms = transforms.ToPILImage() spec = librosa.feature.melspectrogram( - y=y[int(start * sr):end_sr], + y=y, sr=sr, n_fft=self.n_fft, hop_length=self.hop_length, diff --git a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py index a003a73..4a5f5af 100644 --- a/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py +++ b/whoot_model_training/whoot_model_training/preprocessors/waveform_preprocessors.py @@ -2,13 +2,9 @@ Pulled from pyha_analyzer/preprocessors/spectogram_preprocessors.py """ -from dataclasses import dataclass -import librosa import numpy as np - -from pyha_analyzer.preprocessors import PreProcessorBase - +from .default_preprocessor import DefaultPreprocessor, Augmentations # @dataclass # class WaveformParams: @@ -25,20 +21,7 @@ # n_mels: int = 256 -@dataclass -class Augmentations(): - """Dataclass for the augmentations of the model. 
-
-    audio (list[dict]): per item key name of augmentation,
-        value is the augmentation
-    spectrogram (list[dict]): same idea but augmentations
-        applied onto spectrograms
-    """
-    audio = None
-    spectrogram = None
-
-
-class WaveformPreprocessors(PreProcessorBase):
+class WaveformPreprocessors(DefaultPreprocessor):
     """Preprocessor for processing audio into spectrograms.
 
     Particularly for the buow dataset
@@ -72,7 +55,10 @@ def __init__(
         # self.n_mels = spectrogram_params.n_mels
         # self.spectrogram_params = spectrogram_params
 
-        super().__init__(name="MelSpectrogramPreprocessor")
+        super().__init__(
+            name="WaveformPreprocessor",
+            duration=duration,
+            sr=self.sr)
 
     def __call__(self, batch):
         """Process a batch of data from an AudioDataset."""
@@ -80,45 +66,14 @@ def __call__(self, batch):
         new_labels = []
         for item_idx in range(len(batch["audio"])):
             label = batch["labels"][item_idx]
-            try:
-                # TODO: This is a solid section of code for loading audio
-                # Consider turning this into a common helper function
-                if len(batch["audio"][item_idx]["array"]) > 10:
-                    y = batch["audio"][item_idx]["array"]
-                    sr = batch["audio"][item_idx]["sampling_rate"]
-                else:
-                    if librosa.get_duration(
-                        path=batch["audio"][item_idx]["path"]
-                    ) > 2 * 60:
-                        break
-                    y, sr = librosa.load(
-                        path=batch["audio"][item_idx]["path"],
-                        sr=self.sr
-                    )
-
-            except Exception as e:
-                y = np.zeros(self.sr * 5)
-                sr = self.sr
-                print(e)
-                print("File Likely is corrupted, moving on")
-                continue
-
-            start = np.random.uniform(0, len(y)/sr - self.duration)
-            # Handle out of bound issues
-            end_sr = int(start * sr) + int(sr * self.duration)
-            if y.shape[-1] <= end_sr:
-                y = np.pad(y, end_sr - y.shape[-1])
+            y, sr = self.load_audio(batch, item_idx)
 
-            # Audio Based Augmentations
-            if self.augments.audio is not None:
-                y, label = self.augments.audio(y, sr, label)
+            start = np.random.uniform(0, max(0, len(y)/sr - self.duration))
 
-            new_y = y[int(start * sr):end_sr]
-            if new_y.shape[-1] < int(sr * self.duration):
-                continue
+            y, label = self.augment_audio(y, sr, start, label, self.augments)
 
-            new_audio.append(new_y)
+            new_audio.append(y)
             new_labels.append(label)
 
         batch["audio"] = new_audio
@@ -146,6 +101,5 @@ def __repr__(self):
         return (
             f"""{self.name}
             Augmentations: {self.augments}
-            MelSpectrogram: {self.spectrogram_params}
             """
         )
diff --git a/whoot_model_training/whoot_model_training/trainer.py b/whoot_model_training/whoot_model_training/trainer.py
index 0db9265..a83179f 100644
--- a/whoot_model_training/whoot_model_training/trainer.py
+++ b/whoot_model_training/whoot_model_training/trainer.py
@@ -96,7 +96,7 @@ def __init__(
             logger (CometMLLoggerSupplement): Class that adds additional logging
                 On top of logging done by PyhaTrainer
 
-            preprocessor (PreProcessorBase):
+            preprocessor (DefaultPreprocessor):
                 Preprocessor used for formatting the data
         """
         metrics = WhootMutliClassMetrics(dataset.get_class_labels().names)
@@ -120,8 +120,7 @@ def predict(
         test_dataset: AudioDataset,
         ignore_keys=None,
         metric_key_prefix: str = "test",
-        save_path="",
-    ):
+        save_path=""):
         """Run Inferance on a given dataset.
Allows for getting predicted outputs to label a new dataset @@ -132,7 +131,6 @@ def predict( metric_key_prefix: str = "test" Returns: test_dataset with a new col: "pred" """ - # test_dataset = test_dataset.select(range(100)) test_dataloader = self.get_test_dataloader(test_dataset) @@ -140,17 +138,12 @@ def predict( data_selected = [] count = 0 for batch in tqdm(test_dataloader): - try: - pred = self.model( - self.model.input_format(**batch) - )["logits"].detach().cpu().half() - preds.append(pred) - data_selected.extend(range(count, count + len(pred))) - count += len(pred) - except Exception as e: - print(e, "break in batch, don't use") - count += 16 - continue + pred = self.model( + self.model.input_format(**batch) + )["logits"].detach().cpu().half() + preds.append(pred) + data_selected.extend(range(count, count + len(pred))) + count += len(pred) if count % 101 == 0: dataset = test_dataset.with_format()