diff --git a/notebooks/leb_salt_evaluation.ipynb b/notebooks/leb_salt_evaluation.ipynb new file mode 100644 index 0000000..f92f190 --- /dev/null +++ b/notebooks/leb_salt_evaluation.ipynb @@ -0,0 +1,1628 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uDPlSEzTD3cM" + }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install transformers datasets\n", + "!git clone https://github.com/sunbirdai/leb.git\n", + "!pip install -r leb/requirements.txt\n", + "!pip install jiwer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WqCbL5LgQDbe" + }, + "outputs": [], + "source": [ + "import yaml\n", + "import os\n", + "import torch\n", + "import numpy as np\n", + "from torch.utils.data import DataLoader\n", + "from transformers import pipeline\n", + "from datasets import load_metric\n", + "import leb.dataset\n", + "import csv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "h_B_okQFQd2-" + }, + "outputs": [], + "source": [ + "yaml_config = '''\n", + "common_source: &common_source\n", + " type: speech\n", + " language: [lug,eng,ach,nyn]\n", + " preprocessing:\n", + " - set_sample_rate:\n", + " rate: 16_000\n", + "\n", + "common_target: &common_target\n", + " type: text\n", + " language: [lug,eng,ach,nyn]\n", + " preprocessing:\n", + " - lower_case\n", + " - clean_and_remove_punctuation\n", + "\n", + "test:\n", + " huggingface_load:\n", + " - path: Sunbird/salt\n", + " name: multispeaker-lug\n", + " split: test\n", + " - path: Sunbird/salt\n", + " name: multispeaker-eng\n", + " split: test\n", + " - path: Sunbird/salt\n", + " name: multispeaker-ach\n", + " split: test\n", + " - path: Sunbird/salt\n", + " name: multispeaker-nyn\n", + " split: test\n", + "\n", + " source: *common_source\n", + " target: *common_target\n", + "\n", + "'''" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QDoeyTf0UeJA" + }, + "outputs": [], + "source": [ + "# Placeholder only: set a real Hugging Face token here, or export HF_TOKEN.\n", + "auth_token = \"xxxxx\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3632qh57RvTO" + }, + "outputs": [], + "source": [ + "# Write mode ('w'), so that re-running this cell does not append a duplicate config.\n", + "with open(\"config.yaml\", \"w\") as f:\n", + " print(yaml_config, file=f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tyoQnKF4P-hG" + }, + "outputs": [], + "source": [ + "class TranscriptionEvaluator:\n", + " \"\"\"\n", + " Evaluates transcription quality over multiple languages using the Sunbird MMS model,\n", + " measuring Word Error Rate (WER) across a test set for each language.\n", + " \"\"\"\n", + " def __init__(self, config_path):\n", + " with open(config_path) as f:\n", + " self.config = yaml.safe_load(f)\n", + " self.test_ds = leb.dataset.create(self.config['test'])\n", + " self.supported_languages = ['ach', 'lug', 'teo', 'nyn']\n", + " # load_metric is deprecated; newer code should use evaluate.load(\"wer\").\n", + " self.wer_metric = load_metric(\"wer\")\n", + " # Read the token from the environment, falling back to the placeholder above.\n", + " self.auth_token = os.environ.get(\"HF_TOKEN\", auth_token)\n", + " self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + " def transcribe_audio_batch(self, audio_files, languages):\n", + " \"\"\"\n", + " Transcribes a batch of audio files, filtering unsupported languages.\n", + " \"\"\"\n", + " transcriptions, success_flags = [], []\n", + " for language in set(languages):\n",
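+ " # Adapter weights exist only for the languages listed in supported_languages;\n", + " # anything else (e.g. 'eng' under this test config) is skipped rather than failed.\n",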
+ " if language not in self.supported_languages:\n", + " print(f\"Skipping unsupported language: {language}\")\n", + " continue\n", + " # Accumulate per-language results; reassigning here would discard\n", + " # transcriptions from earlier languages in the same batch.\n", + " lang_transcriptions, lang_flags = self.transcribe_for_language(audio_files, languages, language)\n", + " transcriptions.extend(lang_transcriptions)\n", + " success_flags.extend(lang_flags)\n", + " return transcriptions, success_flags\n", + "\n", + " def transcribe_for_language(self, audio_files, languages, language):\n", + " \"\"\"\n", + " Handles transcription for a specific language.\n", + " \"\"\"\n", + " try:\n", + " pipe = self.initialize_pipeline(language)\n", + " lang_indices = [i for i, lang in enumerate(languages) if lang == language]\n", + " lang_audio_files = [audio_files[i] for i in lang_indices]\n", + " outputs = pipe(lang_audio_files)\n", + " return self.process_outputs(outputs, lang_indices), [True] * len(lang_indices)\n", + " except Exception as e:\n", + " print(f\"Error processing language {language}: {e}\")\n", + " return [], []\n", + "\n", + " def initialize_pipeline(self, language):\n", + " \"\"\"\n", + " Initializes the pipeline for a given language.\n", + " \"\"\"\n", + " model_id = \"Sunbird/sunbird-mms\"\n", + " pipe = pipeline(model=model_id, device=self.device, token=self.auth_token)\n", + " pipe.tokenizer.set_target_lang(language)\n", + " pipe.model.load_adapter(language)\n", + " return pipe\n", + "\n", + " def process_outputs(self, outputs, lang_indices):\n", + " \"\"\"\n", + " Processes the pipeline outputs into transcriptions.\n", + " \"\"\"\n", + " # Outputs arrive in the same order as lang_indices, so index by position;\n", + " # indexing by lang_indices[i] overruns this shorter per-language list.\n", + " transcriptions = [None] * len(lang_indices)\n", + " for i, output in enumerate(outputs):\n", + " transcriptions[i] = output[\"text\"]\n", + " return transcriptions\n", + "\n", + " def calculate_batch_wer(self, predictions, references):\n", + " \"\"\"\n", + " Calculates the Word Error Rate (WER) for a batch of predictions and references:\n", + " WER = (substitutions + deletions + insertions) / number of reference words.\n", + " \"\"\"\n", + " return self.wer_metric.compute(predictions=predictions, references=references)\n", + "\n", + " def evaluate(self):\n", + " \"\"\"\n", + " Evaluates the WER across the test set for each language and prints the results.\n", + " \"\"\"\n", + " batch_size = 8\n", + " loader = DataLoader(self.test_ds, batch_size=batch_size, collate_fn=lambda x: x)\n", + " total_wer, total_files = 0, 0\n", + " wer_by_language = {}\n", + "\n", + " # Open a CSV file to write the transcriptions\n", + " with open('transcriptions_comparison.csv', mode='w', newline='', encoding='utf-8') as file:\n", + " writer = csv.writer(file)\n", + " writer.writerow(['Language', 'Predicted Transcription', 'True Transcription'])\n", + "\n", + " for batch in loader:\n", + " audio_files = [np.array(item['source']) for item in batch]\n", + " languages = [item['source.language'] for item in batch]\n", + " true_transcripts = [item['target'] for item in batch]\n", + "\n", + " predicted_transcripts, success_flags = self.transcribe_audio_batch(audio_files, languages)\n", + " filtered_true_transcripts = [t for t, s in zip(true_transcripts, success_flags) if s]\n", + "\n", + " if predicted_transcripts and filtered_true_transcripts:\n", + " print(predicted_transcripts)\n", + " print(filtered_true_transcripts)\n", + " batch_wer = self.calculate_batch_wer(predicted_transcripts, filtered_true_transcripts)\n", + " total_wer += batch_wer * len(filtered_true_transcripts)\n", + " total_files += len(filtered_true_transcripts)\n", + " self.update_language_wer(languages, success_flags, batch_wer, wer_by_language)\n", + "\n", + " # Keep the language labels aligned with the filtered predictions and references.\n", + " filtered_languages = [lang for lang, ok in zip(languages, success_flags) if ok]\n",
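+ " # Rows are recorded only for successfully transcribed items, so per-language\n", + " # row counts may be smaller than the raw test split.\n", + " # Write each transcription pair to the CSV, along with its language\n", + " for language, pred, true in zip(filtered_languages, predicted_transcripts, 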
filtered_true_transcripts):\n", + " writer.writerow([language, pred, true])\n", + "\n", + " self.print_results(wer_by_language, total_wer, total_files)\n", + "\n", + " def update_language_wer(self, languages, success_flags, batch_wer, wer_by_language):\n", + " \"\"\"\n", + " Updates the WER statistics for each language based on the batch results.\n", + " \"\"\"\n", + " filtered_languages = [lang for lang, success in zip(languages, success_flags) if success]\n", + " for language in filtered_languages:\n", + " if language not in wer_by_language:\n", + " wer_by_language[language] = []\n", + " wer_by_language[language].append(batch_wer)\n", + "\n", + " def print_results(self, wer_by_language, total_wer, total_files):\n", + " \"\"\"\n", + " Prints the final WER results by language and overall.\n", + " \"\"\"\n", + " for language, wers in wer_by_language.items():\n", + " avg_wer = sum(wers) / len(wers)\n", + " print(f\"{language}: {avg_wer:.4f}\")\n", + " overall_wer = total_wer / total_files if total_files > 0 else 0\n", + " print(f\"Overall WER across the test set: {overall_wer:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "background_save": true, + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "47c3a3523d104535af689535e7ee57ab", + "cee84e5d41bf4970b34a693c1de98fdc", + "40bc7f419ea5403595c2ba0f01a345c7", + "fbd4f69b98e04585a8c8b101b90c5e30", + "f689f5cca2974b4c9ec59a0b9eabad71", + "70359f9771a94bce934f50a0e52bf0f3", + "1d2107d7376c4d7ca33e89b3d1290b55", + "fdeeb2e496f74d1baab21447796bb7b7", + "616ceb709c2341c2865a0f2b57fba86f", + "b6c3eac83cd64363bdc7eec53aada5cc", + "b012be35eeb84fc2af05a72e742307c1", + "f3b1e257b40e49ed8f41169a50514f44" + ] + }, + "id": "jOb2XOQHGzA3", + "outputId": "1c781aa9-5659-4e9e-e648-ceeb74c2f943" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/datasets/load.py:756: FutureWarning: The repository for wer contains custom code which must be executed to correctly load the metric. You can inspect the repository content at https://raw.githubusercontent.com/huggingface/datasets/2.18.0/metrics/wer/wer.py\n", + "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n", + "Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.\n", + " warnings.warn(\n", + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['ekikola kyakasooli kya kyenvu wabula langi yakyo etera okuba eya kitaka wansi', 'buli gi amabala ameru ku bikoola byakasooli galeetebwa biwuka', 'emikolo kitundu ku bulamu', 'emikolo kitundu ku bulamu', 'ekivulu kyabakazanyyirizi kyabadde kitya', 'ekivulu kya baakazannyirizi kyabadde kitya', 'kolera bulwadde obuleetebwa obukyafu', 'kolera bulwadde obuleetebwa obukyafu']\n", + "['ekikoola kya kasooli kya kyenvu wabula langi yaakyo etera okuba eya kitaka wansi', 'ebikoola bya kasooli biriiriddwa ebisaanyi', 'emikolo kitundu ku bulamu', 'emikolo kitundu ku bulamu', 'ekivvulu kya bakazanyiikirizi kyabadde kitya', 'ekivvulu kya bakazanyiikirizi kyabadde kitya', 'kolera bulwadde obuleetebwa bukyafu', 'kolera bulwadde obuleetebwa bukyafu']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['embuto ezamangu zireetedde omuwendo gwabaana abawala abawanduka mu ssomero okweyongera', 'okusaba kwokweyimirirwa ku kakalu kujja kuwulirwa kkooti ku lwokutaano', 'abantu bangi tebaddizibwa ssente ze bakozesa mu kulabikako mu kkooti', 'buli omu yali mweraliikirivu ku ngeri yokubalirira', 'amakolero agalina abakozi abasukka mu kkumi galowooza kusalako bakozi', 'kimenya mateeka omukozi okukola nga taweebwa luwummula', 'abalimi balina okuwewulwa kwebyo bye basaasaanyizaako okusobola okwongeza ku magoba gaabwe', 'ebisolo bikosebwa nnyo endwadde']\n", + "['embuto ezamangu zireetedde omuwendo gwabaana abawala abawanduka mu ssomero okweyongera', 'okusaba kwokweyimirirwa ku kakalu kujja kuwulirwa kkooti ku lwokutaano', 'abantu bangi tebaddizibwa ssente ze bakozesa mu kulabikako mu kkooti', 'buli omu yali mweraliikirivu ku ngeri yokubalirira', 'amakolero agalina abakozi abasukka mu kkumi galowooza kusalako bakozi', 'kimenya mateeka omukozi okukola nga taweebwa luwummula', 'abalimi balina okuwewulwa kwebyo bye basaasaanyizaako okusobola okwongeza ku magoba gaabwe', 'ebisolo bikosebwa nnyo endwadde']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['ebyobulimi birina okutumbulwa', 'bulijjo tuyina okuyigira mu nsobi zaffe ezaayita', 'kampuni zamafuta zeetaaga okulaba engeri ennungi ezokukwatamu kasasiro', 'buli luvannyuma lwokulonda ebivudde mu kulonda kwabeesimbyewo birangirirwa', 'bano be beesimbyewo okuva mu bitundu byeggwanga ebyenjawulo', 'abayizi abali mu bibiina ebyakamalirizo bafunye okulambikibwa mu bigezo byabwe ebyakamalirizo', 'ttiimu yayolesezza obukodyo bwomupiira obulungi mu mpaka', 'ekirwadde bunansi kyandi tulemye okutangira naye twasobola okukirwanyisa']\n", + "['ebyobulimi birina okutumbulwa', 'bulijjo tulina okuyigira mu nsobi zaffe ezaayita', 'kkampuni zamafuta zeetaaga okulaba engeri ennungi ezokukwatamu kasasiro', 'buli luvannyuma lwokulonda ebivudde mu nkulonda kwabeesimbyewo birangirirwa', 'bano be beesimbyewo okuva mu bintundu byeggwanga ebyenjawulo', 'abayizi abali mu bibiina ebyakamalirizo bafunye okulambikibwa mu bigezo byabwe ebyakamalirizo', 'ttiimu yayolesezza obukodyo bwomupiira obulungi mu mpaka', 'ekirwadde bunnansi kyanditulemye okutangira naye twasobola okukirwanysa']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['balina okukkiriza abaana okwasanguza ebirowoozo byabwe', 'okuwulira emisango gyobuliisamaanyi mu kkooti kutwala bbanga ki', 'okunyweza ebyokwerinda mu kyalo kya mugaso nnyo', 'olukalala okuwandiikibwa abalimi luyamba abakungu okumanya bantu ki abeenyigira mu bulimi', 'omuzannyo azanyira ku mabbali gekisaawe asobola okukola ensobi ezenjawulo mu muzannyo natanenyezebwa', 'ensasula embi ereetera emisango', 'mu bunnabyalo kwe kudduka emisinde emiwanvu', 'weewala okwetaba nabantu abalwadde okwewala okusaasaana kwendwadde ezimu']\n", + "['balina okukkiriza abaana okwasanguza ebirowoozo byabwe', 'okuwulira emisango gyobuliisamaanyi mu kkooti kutwala bbanga ki', 'okunyweza ebyokwerinda mu kyalo kya mugaso nnyo', 'olukalala okuwandiikibwa abalimi luyamba abakungu okumanya bantu ki abeenyigira mu bulimi', 'omuzannyi azannyira ku mabbali gekisaawe asobola okukola ensobi ezenjawulo mu muzannyo natanenyezebwa', 'ensasula embi ereetera emisango', 'mubunabyalo kwe kudduka emisinde emiwanvu', 'weewale okwetaba nabantu abalwadde okwewala okusaasaana kwendwadde ezimu']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['ekitebe kya amerika mu uganda kisiimiddwa olwobuyambi bwakyo eri ebyobulamu bwa uganda', 'waliwo ensonga ezimu ezekuusa ku nkola yamateeka gaffe', 'ebibuga ebijja biri mu kutondebwawo mu ggwanga okusobola okutumbula obuweereza', 'abantu tebajja kukkirizibwa okufuna kitundu ku ssente ze batereka okutuusa nga bawummudde ku mirimu', 'poliisi etunula mu alipoota zokuwamba abantu okwetooloola kampala', 'abantu tebamanyi kye baagala', 'bakansala ba disitulikiti basasulwa bubi ekikosa empeereza', 'abasawo beetaaga basulemu nnyumba zabakozi okusobola okukola amangu ku balwadde']\n", + "['ekitebe kya america mu uganda kisiimiddwa olwobuyambi bwakyo eri ebyobulamu bwa uganda', 'waliwo ensonga ezimu ezeekuusa ku nkola yamateeka gaffe', 'ebibuga ebiggya biri mu kutondebwawo mu ggwanga okusobola okutumbula obuweereza', 'abantu tebajja kukkirizibwa kufuna kitundu ku ssente ze batereka okutuusa nga bawummudde ku mirimu', 'poliisi etunula mu alipoota zokuwamba bantu okwetooloola kampala', 'abantu tebamanyi kye baagala', 'bakansala ba disitulikiti basasulwa bubi ekikosa empeereza', 'abasawo beetaaga basule mu nnyumba zabakozi okusobola okukola amangu ku balwadde']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['emisolo egikungaanyizibwa kuno girina okulondoolwa okulwanyisa okukozesa obubbi bwensimbi', 'katonda wamaanyi', 'okweyongera okwabanoonyiboobubudamu kwaviiriddeko okusaasaana kwobulwadde', 'beetaaga okubuulira abantu ku miganiro gye bajja okufuna', 'okubula ebyokugula nokutunda kitataaganya entambula yoobusuubuzi', 'okulonda kutera kuggwera mu bovu yo', 'abafumbo balina okugonjoola obutakwatagana bwabwe', 'yasadde aka embuzi ye mu bajjajja']\n", + "['emisolo egikungaanyizibwa kuno girina okulondoolwa okulwanisa okukozesa obubbi bwensimbi', 'katonda wa maanyi', 'okweyongera okwabanoonyiboobubudamu kwandiviirako okusaasaana kwobulwadde', 'beetaaga okuuulira abantu ku miganyulo gye bajja okufuna', 'okubula ebyokugula nokutunda kitataaganya entambula yobusuubuzi', 'okulonda kuteera kuggwera mu buvuyo', 'abafumbo balina okugonjoola obutakwatagana bwabwe', 'yasaddako embuzi ye mu ba jjaja']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['ebirowoozo ebizimba bireeta enkulaakulana mu byenfuna', 'ettaka lirina okuwandiisibwa mu mateeka mu mannya ga nnannyini ryo omutuufu', 'amagye tegalina kwetaba mu byabufuzi nokukola ebintu byonna ebyekuusa ku byobufuzi', 'okwemulugunya kwonna okukwata ku mpeereza zemmotoka ezitambuza abalwadde zirina okutwalibwa eri abakulu ba disitulikiti', 'omuwendo gwabakyala bembuto abafa nga bali mbuto nga bazaala oba nga baakamala okuzaala kweyongera', 'obwakabaka nobwami bukyaliwo', 'abakulembeze bamadiini beetaaga okwenyigira mu mirimu gyobusuubuzi okulongoosa obulamu bwabwe', 'omukyala yayanjulidde bazadde be omwami']\n", + "['ebirowoozo ebizimba bireeta enkulaakulana mu byenfuna', 'ettaka lirina okuwandiisibwa mu mateeka mu mannya ga nannyini lyo omutuufu', 'amagye tegalina kwetaba mu byabufuzi nokukola ebintu byonna ebyekuusa ku byobufuzi', 'okwemulugunya kwonna okukwata ku mpereza zemmotoka ezitambuza abalwadde zirina kutwalibwa eri abakulu ba disitulikiti', 'omuwendo gwabakyala bembuto abafa nga bali mbuto nga bazaala oba nga baakamala okuzaala kweyongera', 'obwakabaka nobwami bukyaliwo', 'abakulembeze bamadiini beetaaga okwenyigira mu mirimu gyobusuubuzi okulongoosa obulamu bwabwe', 'omukyala yayanjulidde bazadde be omwami']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['wajja kubaawo pulogulaamu za sikaala eri abo abanaakola obulungi', 'kya mugaso okukuuma nokulabirira obutonde nobulamu bwebisolo byokutale', 'omukungu aweereddwa oluwummula', 'ebikozesebwa mu masomero bijja kugabanyizibwa mu masomero ana mu bitundu byobukiika kkono', 'abanoonyibobubudamu tebalina mmere emala kubeezaawo obulamu bwabwe', 'amasomero gateekeddwa okuwa abayizi ekyokulya', 'gavumenti ejja kwongeza ku nfuluma eyamasannyalaze okuva ku bbibiro lye karuma', 'paaka ya takisi nnungi nnyo okukoleramu emirimu']\n", + "['wajja kubaawo pulogulaamu za sikaala eri abo abanaakola obulungi', 'kya mugaso okukuuma nokulabirira obutonde nobulamu bwebisolo byokuttale', 'omukungu aweereddwa oluwummula', 'ebikozesebwa mu masomero bijja kugabanyizibwa mu masomero ana mu bitundu byobukiikakkono', 'abanoonyiboobubudamu tebalina mmere emala kubeezaawo bulamu bwabwe', 'amasomero gateekeddwa okuwa abayizi ekyokulya', 'gvaumenti ejja kwongeza ku nfulumya yamasannyalaze okuva ku bbibiro lrye karuma', 'paaka ya takisi nnungi nnyo okukoleramu emirimu']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['ebisolo byomu nsiko bireeta ssente okuva ebweru', 'waliwo okwanguyirwa mu kutuuka mu bbanka amasomero nobutale', 'yasaba abantu okwongera okuwagira abalenzi okutuusa bwe bamaliriza emisomo gyabe', 'ettaka kye kimu ku byobugagga ebisingira ddala jjajjawaffe omukyala bye yatulekera', 'akakiiko ka poliisi akakwasisa empisa kaamalawo obutakkaanya', 'omunoonyereza ku musango alina kukola ki', 'ennyanja nnalubaale mpanvu kye nkana ki', 'baakubaganya ebirowoozo ku kusoomoozebwa abayizi kwe basanga mu ssomero']\n", + "['ebisolo byomu nsiko bireeta ssente okuva ebweru', 'waliwo obwanguyirwa mu kutuuka mu bbanka amasomero nobutale', 'yasaba abantu okwongera okuwagira abalenzi okutuusa lwe bamaliriza emisomo gyabwe', 'ettaka kye kimu ku byobugagga ebisingira ddala jjajja waffe omukyala bye yatulekera', 'akakiiko ka poliisi akakwasisa empisa kaamalawo obutakkaanya', 'omunoonyereza ku misango alina kukola ki', 'ennyanja nalubaale mpanvu kyenkana ki', 'baakubaganya ebirowozo ku kusoomoozebwa abayizi kwe basanga mu ssomero']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['ekidyeri kisomba emirundi esatu mu lunaku', 'omuwendo gwabantu abaagala okugula gusalawo bungi ki obulina okuleetebwa', 'ababundabunda beebazizza katonda olokubasindikira abagabi boobuyambi okubayamba', 'abasomi mu bibiina ebyakamalirizo bajja kuddayo ku ntandikwa yomwezi gwomukaaga', 'abantu baanyiiga olwabantu abamu butaagala kuzzaayo ssente', 'abantu bekitundu balina okufuna ebyetaagisa nga gavumenti bwe yabasuubiza', 'bannamateeka bomuwawaabirwa baagamba nti baagala okulaba obwenkanya nga bukolebwa', 'abantu basanyufu olwobuyambi bwemmere okuva mu gavumenti']\n", + "['ekidyeri kisomba emirundi esatu mu lunaku', 'omuwendo gwabantu abaagala okugula gusalawo bungi ki obulina okuleetebwa', 'ababundabunda bebaziza katonda olwokubasindikira abagabi bobuyambi okubayamba', 'abasomi mu bibiina ebya kamalirizo bajja kuddayo ku ntandiikwa yomwezi gwomukaaga', 'abantu baanyiiga olwabantu abamu obutaagala kuzzaayo ssente', 'abantu bekitundu balina okufuna ebyetaagisa nga gavumenti bwe yabasuubiza', 'bannamateeka bomuwawaabirwa baagamba nti baagala okulaba obwenkanya nga bukolebwa', 'abantu basanyufu olwobuyambi bwemmere okuva mu gavumenti']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['abantu basiimye ekitongole olwebyo bye kigezaako okubakolera', 'akakiiko akakubirizi kalina obukuubagano nekitongole kyamawulire ekya westnitle press association', 'abakulembeze beddiini balina okuba abeetoowaze ate nga ba mpisa', 'kirungi okulaga okusiima', 'buli lwokola ekintu ekitali kituufu bambi weetonde', 'tteekateeka ki ezimu kuziteekeddwawo okulwanyisa akawuka ka kolona', 'abatembeye abeewandiisa beemulugunya ku bukyafu obuli mu katale', 'tulinayo mulwadde yenna leero']\n", + "['abantu basiimye ekitongole olwebyo bye kigezaako okubakolera', 'akakiiko akakubirizi kayina obukuubagano nekitongole kyamawulire ekya west nile press association', 'abakulembeze beddiini balina okuba abeetoowaze ate nga ba mpisa', 'kirungi okulaga okusiima', 'buli lwokola ekintu ekitali kituufu bambi weetonde', 'nteekateeka ki ezimu ku ziteekeddwawo okulwanyisa akawuka ka kolona', 'abatembeeyi abeewandiisa beemulugunya ku bukyafu obuli mu katale', 'tuyinayo omulwadde yenna leero']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['akawuka ka kolona kasaasaana mangu', 'obuli bwenguzi kye ki', 'kakensa aya kolokose empapula ze ezobuyigirize', 'tekinologiya aggulawo enzigi zemikisa empya okukola bizinensi', 'tulindiridde nobugumiikiriza okuteekayo ebitaala byo ku nguudo', 'abalimi abasinga balima ebirime nga bya kutunda', 'omukulembeze asiimye ttiimu yeggwanga eyomupiira gwebigere olwokukola obulungi', 'abawagizi bomupiira baagala ttiimu ewangula']\n", + "['akawuka ka kolona kasaasaana mangu', 'obuli bwenguzi kye ki', 'kakensa yakolokose empapula ze ezobuyigirize', 'tekinologiya aggulawo enzigi zemikisa empya okukola bizinensi', 'tulindiridde nobugumiikiriza okuteekayo ebitala byoku nguudo', 'abalimi abasinga balima ebirime nga bya kutunda', 'abakulembeze asiimye ttiimu yeggwanga eyomupiira gwebigere olwokukola obulungi', 'abawagizi bomupiira baagala ttiimu ewangula']\n", + "Skipping unsupported language: eng\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['enzige bwe zirumba ebirime zoonona ebyobulimi byonna']\n", + "['enzige bwe zirumba ebirime zoonoona ebyobulimi byonna']\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n", + "Skipping unsupported language: eng\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "47c3a3523d104535af689535e7ee57ab", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "adapter.ach.safetensors: 0%| | 0.00/8.81M [00:00<?, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":11: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n", + " self.wer_metric = load_metric(\"wer\")\n", + "/usr/local/lib/python3.10/dist-packages/datasets/load.py:756: FutureWarning: The repository for wer contains custom code which must be executed to correctly load the metric. 
You can inspect the repository content at https://raw.githubusercontent.com/huggingface/datasets/2.18.0/metrics/wer/wer.py\n", + "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n", + "Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n", + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping unsupported language: eng\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Skipping unsupported language: eng\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error processing language ach: list assignment index out of range\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error processing language nyn: list assignment index out of range\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "Some weights of the model checkpoint at Sunbird/sunbird-mms were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Sunbird/sunbird-mms and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lug: 0.1375\n", + "ach: 0.2974\n", + "nyn: 0.3918\n", + "Overall WER across the test set: 0.2530\n" + ] + } + ], + "source": [ + "# evaluator = TranscriptionEvaluator(\"config.yaml\")\n", + "# evaluator.evaluate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 162 + }, + "id": "RdRi8-YIQn7X", + "outputId": "5ba45092-9f8a-431a-9f94-51562de611e6" + }, + "outputs": [ + { + "ename": "TypeError", + "evalue": "TranscriptionEvaluator.print_results() missing 3 required positional arguments: 'wer_by_language', 'total_wer', and 'total_files'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mevaluator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_results\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m: TranscriptionEvaluator.print_results() missing 3 required positional arguments: 'wer_by_language', 'total_wer', and 'total_files'" + ] + } + ], + "source": [ + "evaluator.print_results()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "1d2107d7376c4d7ca33e89b3d1290b55": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "40bc7f419ea5403595c2ba0f01a345c7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + 
"description_tooltip": null, + "layout": "IPY_MODEL_fdeeb2e496f74d1baab21447796bb7b7", + "max": 8808780, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_616ceb709c2341c2865a0f2b57fba86f", + "value": 8808780 + } + }, + "47c3a3523d104535af689535e7ee57ab": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_cee84e5d41bf4970b34a693c1de98fdc", + "IPY_MODEL_40bc7f419ea5403595c2ba0f01a345c7", + "IPY_MODEL_fbd4f69b98e04585a8c8b101b90c5e30" + ], + "layout": "IPY_MODEL_f689f5cca2974b4c9ec59a0b9eabad71" + } + }, + "616ceb709c2341c2865a0f2b57fba86f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "70359f9771a94bce934f50a0e52bf0f3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b012be35eeb84fc2af05a72e742307c1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b6c3eac83cd64363bdc7eec53aada5cc": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": 
null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cee84e5d41bf4970b34a693c1de98fdc": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_70359f9771a94bce934f50a0e52bf0f3", + "placeholder": "​", + "style": "IPY_MODEL_1d2107d7376c4d7ca33e89b3d1290b55", + "value": "adapter.ach.safetensors: 100%" + } + }, + "f689f5cca2974b4c9ec59a0b9eabad71": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fbd4f69b98e04585a8c8b101b90c5e30": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b6c3eac83cd64363bdc7eec53aada5cc", + "placeholder": "​", + "style": "IPY_MODEL_b012be35eeb84fc2af05a72e742307c1", + "value": " 8.81M/8.81M [00:00<00:00, 37.2MB/s]" + } + }, + "fdeeb2e496f74d1baab21447796bb7b7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + 
"state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/notebooks/stt-finetune.ipynb b/notebooks/stt-finetune.ipynb new file mode 100644 index 0000000..923999d --- /dev/null +++ b/notebooks/stt-finetune.ipynb @@ -0,0 +1,873 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q1paBwQoLcYk" + }, + "outputs": [], + "source": [ + "%%capture\n", + "# Later this will just be 'pip install leb'\n", + "!git clone https://github.com/jqug/leb.git\n", + "!pip install -r leb/requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gnLH11bqcqFi" + }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install transformers[torch]\n", + "!pip install accelerate -U\n", + "!pip install jiwer\n", + "!pip install omegaconf\n", + "!pip install datasets\n", + "!pip install sacremoses" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nWe3IWmIY-2l" + }, + "outputs": [], + "source": [ + "from torch import nn\n", + "import torch\n", + "from transformers import (\n", + " AutoFeatureExtractor,\n", + " AutoModelForCTC,\n", + " AutoProcessor,\n", + " HfArgumentParser,\n", + " Trainer,\n", + " TrainingArguments,\n", + " Wav2Vec2CTCTokenizer,\n", + " Wav2Vec2FeatureExtractor,\n", + " Wav2Vec2ForCTC,\n", + " Wav2Vec2Processor,\n", + " is_apex_available,\n", + " set_seed,\n", + ")\n", + "from dataclasses import dataclass, field\n", + "from typing import Union, List, Dict\n", + "import string\n", + "import datasets\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lQwY1vmzP2XP" + }, + "outputs": [], + "source": [ + "import leb.dataset\n", + "import yaml" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qJoF_8fytNPn" + }, + "source": [ + "# ASR data example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uUFWQ2dIAwmT" + }, + "outputs": [], + "source": [ + "yaml_config = '''\n", + "huggingface_load:\n", + " path: Sunbird/salt\n", + " split: train\n", + " name: multispeaker-lug\n", + "source:\n", + " type: speech\n", + " language: lug\n", + " preprocessing:\n", + " - set_sample_rate:\n", + " rate: 16_000\n", + "target:\n", + " type: text\n", + " language: lug\n", + " preprocessing:\n", + " - lower_case\n", + " - remove_punctuation\n", + "\n", + "'''\n", + "\n", + "config = yaml.safe_load(yaml_config)\n", + 
"train_ds = leb.dataset.create(config)" + ] + }, + { + "cell_type": "code", + "source": [ + "yaml_config = '''\n", + "huggingface_load:\n", + " path: Sunbird/salt\n", + " split: dev\n", + " name: multispeaker-lug\n", + "source:\n", + " type: speech\n", + " language: lug\n", + " preprocessing:\n", + " - set_sample_rate:\n", + " rate: 16_000\n", + "target:\n", + " type: text\n", + " language: lug\n", + " preprocessing:\n", + " - lower_case\n", + " - remove_punctuation\n", + "\n", + "'''\n", + "\n", + "config = yaml.safe_load(yaml_config)\n", + "eval_ds = leb.dataset.create(config)\n", + "\n", + "leb.utils.show_dataset(eval_ds.take(5), audio_features=['source'])" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 536 + }, + "id": "6yYb3FZvCQU6", + "outputId": "cf967975-9294-44d9-a9d0-17137664a4ef" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sourcetarget
0waliwo pulogulaamu nnyingi ezokweggya mu bwavu ezeetooloorera ku byobulimi nobulunzi
1bafuna obubaka ku budde bwokusimba enkozesa yebimera ennungi nebiddiria oluvannyuma lwamakungula
2enkuyege zifuuse ensonga ennene ennyo mu nnimiro eno
3ebikoola byekimera bikwatiddwa obulwadde
4ensi yonna eri mu kirwadde bbunansi
" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pxkMnzyIKyGa" + }, + "outputs": [], + "source": [ + "# Create dict for vocabulary\n", + "def extract_all_chars(batch):\n", + " all_text = \" \".join(batch[\"target\"])\n", + " vocab = list(set(all_text))\n", + " return {\"vocab\": vocab, \"all_text\": [all_text]}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1zsYt5qpNQYu" + }, + "outputs": [], + "source": [ + "vocab_dict = {}\n", + "\n", + "for item in train_ds:\n", + " result = extract_all_chars(item)\n", + " for char in result[\"vocab\"]:\n", + " vocab_dict[char] = 1\n", + "\n", + "vocab_list = list(vocab_dict.keys())\n", + "vocab_dict = {v: k for k, v in enumerate(vocab_list)}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JsxcOK8Taig8" + }, + "outputs": [], + "source": [ + "vocab_dict[\"|\"] = vocab_dict[\" \"]\n", + "vocab_dict[\"[UNK]\"] = len(vocab_dict)\n", + "vocab_dict[\"[PAD]\"] = len(vocab_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dkfDu9wKaySy" + }, + "outputs": [], + "source": [ + "target_lang = \"lug\"\n", + "new_vocab_dict = {target_lang: vocab_dict}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Xm19vvSna1S-", + "outputId": "ed457c50-c990-4a49-b072-26f53309b692" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'lug': {'k': 0,\n", + " 'b': 1,\n", + " 'e': 2,\n", + " 'l': 3,\n", + " 's': 4,\n", + " 'u': 5,\n", + " 'i': 6,\n", + " 'a': 7,\n", + " 'y': 8,\n", + " ' ': 9,\n", + " 'g': 10,\n", + " 'm': 11,\n", + " 'n': 12,\n", + " 'r': 13,\n", + " 'o': 14,\n", + " 'z': 15,\n", + " 'd': 16,\n", + " 't': 17,\n", + " 'w': 18,\n", + " 'f': 19,\n", + " 'v': 20,\n", + " 'j': 21,\n", + " 'p': 22,\n", + " 'c': 23,\n", + " 'h': 24,\n", + " 'x': 25,\n", + " '|': 9,\n", + " '[UNK]': 27,\n", + " '[PAD]': 28}}" + ] + }, + "metadata": {}, + "execution_count": 11 + } + ], + "source": [ + "new_vocab_dict" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "jRpm-LStliuG" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def prepare_dataset(batch):\n", + " # check that all files have the correct sampling rate\n", + " # print(batch)\n", + " # print(batch.keys())\n", + " # assert (\n", + " # len(set(batch[\"sampling_rate\"])) == 1\n", + " # ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n", + "\n", + " batch[\"input_values\"] = processor(\n", + " batch[\"source\"], sampling_rate=16000\n", + " ).input_values\n", + " # Setup the processor for targets\n", + " # with processor.as_target_processor():\n", + " # batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n", + " batch[\"labels\"] = processor(text=batch[\"target\"]).input_ids\n", + "\n", + " return batch" + ], + "metadata": { + "id": "XhNk5ezUli16" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "final_train_dataset = train_ds.map(\n", + " prepare_dataset,\n", + " batch_size=4,\n", + " batched=True,\n", + ")" + ], + "metadata": { + "id": "vb8_bKTnli4N" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "final_val_dataset = eval_ds.map(\n", + " prepare_dataset,\n", + " 
" batch_size=4,\n", + " batched=True,\n", + ")" + ], + "metadata": { + "id": "wgDoufXBEa9r" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CfM5jq_1a6EK" + }, + "outputs": [], + "source": [ + "import json\n", + "with open(\"vocab.json\", \"w\") as vocab_file:\n", + " json.dump(new_vocab_dict, vocab_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nTtabOFYbFWm" + }, + "outputs": [], + "source": [ + "@dataclass\n", + "class DataCollatorCTCWithPadding:\n", + " \"\"\"\n", + " Data collator that will dynamically pad the inputs received.\n", + " Args:\n", + " processor (:class:`~transformers.Wav2Vec2Processor`)\n", + " The processor used for processing the data.\n", + " padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):\n", + " Select a strategy to pad the returned sequences (according to the model's padding side and padding index)\n", + " among:\n", + " * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single\n", + " sequence is provided).\n", + " * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the\n", + " maximum acceptable input length for the model if that argument is not provided.\n", + " * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of\n", + " different lengths).\n", + " \"\"\"\n", + "\n", + " processor: Wav2Vec2Processor\n", + " padding: Union[bool, str] = True\n", + "\n", + " def __call__(\n", + " self, features: List[Dict[str, Union[List[int], torch.Tensor]]]\n", + " ) -> Dict[str, torch.Tensor]:\n", + " # Split inputs and labels, since they have different lengths and need\n", + " # different padding methods.\n", + " input_features = [\n", + " {\"input_values\": feature[\"input_values\"]} for feature in features\n", + " ]\n", + " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", + "\n", + " batch = self.processor.pad(\n", + " input_features,\n", + " padding=self.padding,\n", + " return_tensors=\"pt\",\n", + " )\n", + " labels_batch = self.processor.pad(\n", + " labels=label_features,\n", + " padding=self.padding,\n", + " return_tensors=\"pt\",\n", + " )\n", + "\n", + " # Replace label padding with -100 so it is ignored by the CTC loss.\n", + " labels = labels_batch[\"input_ids\"].masked_fill(\n", + " labels_batch.attention_mask.ne(1), -100\n", + " )\n", + "\n", + " batch[\"labels\"] = labels\n", + "\n", + " return batch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SUCxtD45aulU" + }, + "outputs": [], + "source": [ + "tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(\"./\", unk_token=\"[UNK]\", pad_token=\"[PAD]\", word_delimiter_token=\"|\", target_lang=target_lang)\n", + "feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)\n", + "\n", + "processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RQA4mE1qSVqu", + "outputId":
"529a41da-9a87-4bab-f474-2751d30f645e" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":1: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n", + " wer_metric = datasets.load_metric(\"wer\")\n", + "/usr/local/lib/python3.10/dist-packages/datasets/load.py:753: FutureWarning: The repository for wer contains custom code which must be executed to correctly load the metric. You can inspect the repository content at https://raw.githubusercontent.com/huggingface/datasets/2.17.0/metrics/wer/wer.py\n", + "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n", + "Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "wer_metric = datasets.load_metric(\"wer\")" + ] + }, + { + "cell_type": "code", + "source": [ + "data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)" + ], + "metadata": { + "id": "PW5_IaQ1g-XS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7k0QwkkKb-9M" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def compute_metrics(pred):\n", + " pred_logits = pred.predictions\n", + " pred_ids = np.argmax(pred_logits, axis=-1)\n", + "\n", + " pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n", + "\n", + " pred_str = processor.batch_decode(pred_ids)\n", + " # we do not want to group tokens when computing the metrics\n", + " label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n", + "\n", + " wer = wer_metric.compute(predictions=pred_str, references=label_str)\n", + " return {\"wer\": wer}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OfCtBYqEblKh", + "outputId": "4137f09d-424b-4ed9-e932-624edce2bc86" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Some weights of the model checkpoint at facebook/mms-1b-all were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']\n", + "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/mms-1b-all and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/mms-1b-all and are newly initialized because the shapes did not match:\n", + "- lm_head.bias: found shape torch.Size([154]) in the checkpoint and torch.Size([31]) in the model instantiated\n", + "- lm_head.weight: found shape torch.Size([154, 1280]) in the checkpoint and torch.Size([31, 1280]) in the model instantiated\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model = Wav2Vec2ForCTC.from_pretrained(\n", + " \"facebook/mms-1b-all\",\n", + " attention_dropout=0.0,\n", + " hidden_dropout=0.0,\n", + " feat_proj_dropout=0.0,\n", + " layerdrop=0.0,\n", + " ctc_loss_reduction=\"mean\",\n", + " pad_token_id=processor.tokenizer.pad_token_id,\n", + " vocab_size=len(processor.tokenizer),\n", + " ignore_mismatched_sizes=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Vd84XbaXb3JV" + }, + "outputs": [], + "source": [ + "model.gradient_checkpointing_enable()\n", + "model.init_adapter_layers()\n", + "model.freeze_base_model()\n", + "\n", + "adapter_weights = model._get_adapters()\n", + "for param in adapter_weights.values():\n", + " param.requires_grad = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0rBwCRvAb57u" + }, + "outputs": [], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir=\"output/mms-lug\",\n", + " # group_by_length=True,\n", + " per_device_train_batch_size=2,\n", + " evaluation_strategy=\"steps\",\n", + " num_train_epochs=5,\n", + " max_steps=10000,\n", + " gradient_checkpointing=True,\n", + " fp16=True,\n", + " save_steps=1000,\n", + " eval_steps=1000,\n", + " logging_steps=1000,\n", + " learning_rate=1e-3,\n", + " warmup_steps=100,\n", + " save_total_limit=2,\n", + " # push_to_hub=True,\n", + " # report_to=\"wandb\",\n", + " run_name=\"mms-lug\",\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model=\"wer\",\n", + " greater_is_better=False,\n", + " weight_decay=0.01,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3MSyS7ZacR7j" + }, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " data_collator=data_collator,\n", + " args=training_args,\n", + " compute_metrics=compute_metrics,\n", + " train_dataset=final_train_dataset,\n", + " eval_dataset=final_val_dataset,\n", + " tokenizer=processor.feature_extractor,\n", + ")" + ] + }, + { + "cell_type": "code", + "source": [ + "trainer.train()" + ], + "metadata": { + "id": "07dSvyLcgrQS", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 784 + }, + "outputId": "d27750af-228c-4a90-d493-1edf2cd6ae42" + }, + 
"execution_count": null, + "outputs": [ + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [ 4484/10000 59:18 < 1:12:59, 1.26 it/s, Epoch 1.20/9223372036854775807]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation LossWer
10000.1706000.2700760.281981
20000.1769000.2498650.277854
30000.2568000.2584470.309491
40000.1152000.2366130.283356

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [ 7594/10000 1:42:24 < 32:27, 1.24 it/s, Epoch 3.01/9223372036854775807]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation LossWer
10000.1706000.2700760.281981
20000.1769000.2498650.277854
30000.2568000.2584470.309491
40000.1152000.2366130.283356
50000.2276000.2102570.270977
60000.1495000.2226110.276479
70000.1382000.2185850.277854

" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n", + " warnings.warn(\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "g5yPYYE4g6CK" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file