diff --git a/README.md b/README.md
index 09be0b3..a1ab6b2 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,7 @@ We also find that it is essential to ensure that the unlabeled audio has vocal p
 - Python Library
   - Keras 2.3.0 (Deep Learning library)
+    - or PyTorch 1.7.0 if you use the --run_on_torch option
   - Librosa 0.7.0 (for STFT)
   - madmom 0.16.1 (for loading audio and resampling)
   - Numpy, SciPy
@@ -39,6 +40,7 @@ $ python melodyExtraction_NS.py -p ./audio/test_audio_file.mp4 -o ./results/ -gp
 -gpu gpu_index       Assign a gpu index for processing. It will run with cpu if None. (default: None)
 -o output_dir        Path to output folder (default: ./results/)
+-torch run_on_torch  Run on PyTorch instead of Keras. The output may differ slightly from the Keras model.
 ```
diff --git a/melodyExtraction_NS.py b/melodyExtraction_NS.py
index 555417b..e813dd8 100644
--- a/melodyExtraction_NS.py
+++ b/melodyExtraction_NS.py
@@ -7,12 +7,10 @@
 import sys
 import matplotlib.pyplot as plt
 import glob
-
-from model import *
 
 from featureExtraction import *
 
 
-def melodyExtraction_NS(file_name, output_path, gpu_index):
+def melodyExtraction_NS(file_name, output_path, gpu_index, run_on_torch):
     os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
     if gpu_index is None:
         os.environ['CUDA_VISIBLE_DEVICES'] = ''
@@ -26,11 +25,28 @@ def melodyExtraction_NS(file_name, output_path, gpu_index):
     ''' Features extraction'''
     X_test, X_spec = spec_extraction(
        file_name=file_name, win_size=31)
-
+
     ''' melody predict'''
-    model = melody_ResNet()
-    model.load_weights('./weights/ResNet_NS.hdf5')
-    y_predict = model.predict(X_test, batch_size=64, verbose=1)
+
+    if run_on_torch:
+        # Import lazily so PyTorch stays an optional dependency.
+        import torch
+        from model_torch import Melody_ResNet as TorchModel
+        model = TorchModel()
+        model.load_state_dict(torch.load('./weights/torch_weights.pt'))
+        torch_input = torch.Tensor(X_test).permute(0, 3, 1, 2)
+        if gpu_index is not None:
+            model = model.to('cuda')
+            torch_input = torch_input.to('cuda')
+        model.eval()
+        with torch.no_grad():
+            y_predict = model(torch_input)
+        y_predict = y_predict.cpu().numpy()
+    else:
+        from model import melody_ResNet
+        model = melody_ResNet()
+        model.load_weights('./weights/ResNet_NS.hdf5')
+        y_predict = model.predict(X_test, batch_size=64, verbose=1)
 
     y_shape = y_predict.shape
     num_total_frame = y_shape[0]*y_shape[1]
@@ -63,18 +78,20 @@
 def parser():
     p = argparse.ArgumentParser()
     p.add_argument('-p', '--filepath',
                    help='Path to input audio (default: %(default)s',
-                   type=str, default='test_audio_file.mp4')
+                   type=str, default='audio/test_audio.mp4')
     p.add_argument('-o', '--output_dir',
                    help='Path to output folder (default: %(default)s',
                    type=str, default='./results/')
     p.add_argument('-gpu', '--gpu_index',
                    help='Assign a gpu index for processing. It will run with cpu if None. (default: %(default)s',
                    type=int, default=None)
+    p.add_argument('-torch', '--run_on_torch',
+                   help='Run on PyTorch instead of Keras', default=False, action='store_true')
     return p.parse_args()
 
 
 if __name__ == '__main__':
     args = parser()
     melodyExtraction_NS(file_name=args.filepath,
-                        output_path=args.output_dir, gpu_index=args.gpu_index)
+                        output_path=args.output_dir, gpu_index=args.gpu_index, run_on_torch=args.run_on_torch)
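For reference, the new flag slots into the existing CLI like this (paths follow the README defaults above; `-torch` is the short form of `--run_on_torch`):

```
$ python melodyExtraction_NS.py -p ./audio/test_audio_file.mp4 -o ./results/ -gpu 0
$ python melodyExtraction_NS.py -p ./audio/test_audio_file.mp4 -o ./results/ -gpu 0 -torch
```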
diff --git a/model_torch.py b/model_torch.py
new file mode 100644
index 0000000..5486466
--- /dev/null
+++ b/model_torch.py
@@ -0,0 +1,78 @@
+import torch
+import torch.nn as nn
+import math
+
+
+class ConvNorm(nn.Module):
+    """Conv2d + BatchNorm + LeakyReLU, mirroring the Keras conv blocks."""
+
+    def __init__(self, in_channels, out_channels, kernel_size):
+        super(ConvNorm, self).__init__()
+        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+                              kernel_size=kernel_size, padding=(kernel_size-1) // 2, bias=False)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.01)
+        self.activation = nn.LeakyReLU(0.01)
+
+    def forward(self, x, return_shortcut=False, skip_activation=False):
+        shortcut = self.bn(self.conv(x))
+        if return_shortcut:
+            return self.activation(shortcut), shortcut
+        elif skip_activation:
+            return shortcut
+        else:
+            return self.activation(shortcut)
+
+
+class ResNet_Block(nn.Module):
+    def __init__(self, num_input_ch, num_channels):
+        super(ResNet_Block, self).__init__()
+        self.conv1 = ConvNorm(num_input_ch, num_channels, 1)
+        self.conv2 = ConvNorm(num_channels, num_channels, 3)
+        self.conv3 = ConvNorm(num_channels, num_channels, 3)
+        self.conv4 = ConvNorm(num_channels, num_channels, 1)
+
+    def forward(self, x):
+        x, shortcut = self.conv1(x, return_shortcut=True)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x, skip_activation=True)
+
+        x += shortcut
+        x = self.conv4.activation(x)
+        x = torch.max_pool2d(x, (1, 4))
+
+        return x
+
+
+class Melody_ResNet(nn.Module):
+    def __init__(self):
+        super(Melody_ResNet, self).__init__()
+        self.block = nn.Sequential(
+            ResNet_Block(1, 64),
+            ResNet_Block(64, 128),
+            ResNet_Block(128, 192),
+            ResNet_Block(192, 256),
+        )
+
+        # Keras defaults to hard_sigmoid for the LSTM recurrent activation, but in tests
+        # a plain sigmoid in PyTorch gave results closest to the pre-trained Keras model.
+        # PyTorch's LSTM also has no recurrent_dropout, and its dropout argument is a
+        # no-op with a single layer, so no dropout is applied here.
+        self.lstm = nn.LSTM(512, 256, bidirectional=True, batch_first=True)
+        num_output = int(55 * 2 ** math.log2(8) + 2)  # = 442; math.log(8, 2) can round just below 3
+        self.final = nn.Linear(512, num_output)
+        self.batch_size = 300
+
+    def forward(self, input):
+        total_output = []
+        # Process at most batch_size windows at a time to bound memory use.
+        for i in range(math.ceil(input.shape[0]/self.batch_size)):
+            block = self.block(input[i*self.batch_size:(i+1)*self.batch_size])  # channel first for torch
+            numOutput_P = block.shape[1] * block.shape[3]
+            reshape_out = block.permute(0, 2, 3, 1).reshape(block.shape[0], 31, numOutput_P)
+
+            lstm_out, _ = self.lstm(reshape_out)
+            out = self.final(lstm_out)
+            out = torch.softmax(out, dim=-1)
+            total_output.append(out)
+        return torch.cat(total_output)
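As a quick sanity check of `Melody_ResNet`, here is a minimal sketch assuming the 513-bin spectrogram input that `spec_extraction` produces: the four `(1, 4)` max-pools reduce 513 bins to 2, which gives the LSTM its 2 x 256 = 512 input features.

```python
import torch
from model_torch import Melody_ResNet

model = Melody_ResNet()
model.eval()

# Random stand-in for spec_extraction output, already permuted to
# (batch, channel, time, freq) as in melodyExtraction_NS.py.
x = torch.randn(2, 1, 31, 513)
with torch.no_grad():
    y = model(x)

print(y.shape)  # torch.Size([2, 31, 442]): per-frame distribution over pitch labels
```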
diff --git a/torch_weight_conversion.ipynb b/torch_weight_conversion.ipynb
new file mode 100644
index 0000000..535b752
--- /dev/null
+++ b/torch_weight_conversion.ipynb
@@ -0,0 +1,150 @@
+{
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.11-final"
+  },
+  "orig_nbformat": 2,
+  "kernelspec": {
+   "name": "python361164bit3aacfbc5ce1d45a7860a65ab2586d2ee",
+   "display_name": "Python 3.6.11 64-bit",
+   "language": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stderr",
+     "text": [
+      "Using TensorFlow backend.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from model import melody_ResNet\n",
+    "from model_torch import Melody_ResNet\n",
+    "import torch\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "keras_model = melody_ResNet()\n",
+    "keras_model.load_weights('./weights/ResNet_NS.hdf5')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "total_weights = []\n",
+    "for layer in keras_model.layers:\n",
+    "    weights = layer.get_weights()\n",
+    "    if weights != []:\n",
+    "        total_weights.append({'name': layer.name, 'weights': weights })"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": [
+       "34"
+      ]
+     },
+     "metadata": {},
+     "execution_count": 5
+    }
+   ],
+   "source": [
+    "len(total_weights)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "[(0, 'conv_s1_1x1'), (1, 'batch_normalization_1'), (2, 'conv1_1'), (3, 'batch_normalization_2'), (4, 'conv1_2'), (5, 'batch_normalization_3'), (6, 'conv_f1_1x1'), (7, 'batch_normalization_4'), (8, 'conv_s2_1x1'), (9, 'batch_normalization_5'), (10, 'conv2_1'), (11, 'batch_normalization_6'), (12, 'conv2_2'), (13, 'batch_normalization_7'), (14, 'conv_f2_1x1'), (15, 'batch_normalization_8'), (16, 'conv_s3_1x1'), (17, 'batch_normalization_9'), (18, 'conv3_1'), (19, 'batch_normalization_10'), (20, 'conv3_2'), (21, 'batch_normalization_11'), (22, 'conv_f3_1x1'), (23, 'batch_normalization_12'), (24, 'conv_s4_1x1'), (25, 'batch_normalization_13'), (26, 'conv4_1'), (27, 'batch_normalization_14'), (28, 'conv4_2'), (29, 'batch_normalization_15'), (30, 'conv_f4_1x1'), (31, 'batch_normalization_16'), (32, 'bidirectional_1'), (33, 'time_distributed_1')]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print([(i, x['name']) for i,x in enumerate(total_weights)])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch_model = Melody_ResNet()\n",
+    "\n",
+    "for i in range(4):\n",
+    "    block_weights = total_weights[i*8:(i+1)*8]\n",
+    "    for j in range(4):\n",
+    "        getattr(torch_model.block[i], f'conv{j+1}').conv.weight.data = torch.from_numpy(np.transpose(block_weights[j*2]['weights'][0], (3, 2, 0, 1)))\n",
+    "        getattr(torch_model.block[i], f'conv{j+1}').bn.weight.data = torch.from_numpy(block_weights[j*2+1]['weights'][0])\n",
+    "        getattr(torch_model.block[i], f'conv{j+1}').bn.bias.data = torch.from_numpy(block_weights[j*2+1]['weights'][1])\n",
+    "        getattr(torch_model.block[i], f'conv{j+1}').bn.running_mean.data = torch.from_numpy(block_weights[j*2+1]['weights'][2])\n",
+    "        getattr(torch_model.block[i], f'conv{j+1}').bn.running_var.data = torch.from_numpy(block_weights[j*2+1]['weights'][3])\n",
+    "\n",
+    "lstm_weight_name = ['weight_ih_l0',\n",
+    "                    'weight_hh_l0',\n",
+    "                    'bias_ih_l0',\n",
+    "                    'bias_hh_l0',\n",
+    "                    'weight_ih_l0_reverse',\n",
+    "                    'weight_hh_l0_reverse',\n",
+    "                    'bias_ih_l0_reverse',\n",
+    "                    'bias_hh_l0_reverse']\n",
+    "keras_lstm_index = [0,1,2,2,3,4,5,5]\n",
+    "\n",
+    "for i, name in enumerate(lstm_weight_name):\n",
+    "    if i in (0,1,4,5):\n",
+    "        getattr(torch_model.lstm, name).data = torch.from_numpy(np.transpose(total_weights[-2]['weights'][keras_lstm_index[i]]))\n",
+    "    else:\n",
+    "        weight = total_weights[-2]['weights'][keras_lstm_index[i]]\n",
+    "        getattr(torch_model.lstm, name).data = torch.from_numpy(weight) / 2\n",
+    "\n",
+    "\n",
+    "torch_model.final.weight.data = torch.from_numpy(np.transpose(total_weights[-1]['weights'][0]))\n",
+    "\n",
+    "torch_model.final.bias.data = torch.from_numpy(total_weights[-1]['weights'][1])\n",
+    "\n",
+    "torch.save(torch_model.state_dict(), './weights/torch_weights.pt')\n"
+   ]
+  }
+ ]
+}
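One non-obvious step in the conversion above is the bias handling. A Keras `Bidirectional(LSTM)` stores a single bias vector per direction, which is why `keras_lstm_index = [0,1,2,2,3,4,5,5]` reuses indices 2 and 5, while PyTorch keeps two vectors, `bias_ih` and `bias_hh`, and adds them, so splitting each Keras bias in half preserves the sum. The gate order (input, forget, cell, output) is the same in both frameworks, so no reordering is needed. A minimal sketch with this model's shapes:

```python
import numpy as np
import torch

# Keras bias for one LSTM direction: (4 * hidden,) = (1024,), gates ordered [i, f, c, o].
b_keras = np.random.rand(1024).astype(np.float32)

# PyTorch computes ... + b_ih + b_hh, so b/2 + b/2 reproduces the Keras bias.
b_ih = torch.from_numpy(b_keras) / 2
b_hh = torch.from_numpy(b_keras) / 2
assert torch.allclose(b_ih + b_hh, torch.from_numpy(b_keras))
```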
diff --git a/weights/torch_weights.pt b/weights/torch_weights.pt
new file mode 100644
index 0000000..003c01f
Binary files /dev/null and b/weights/torch_weights.pt differ
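To check the conversion end to end, the two backends can be compared on the same random input. A sketch, assuming both weight files are present under `./weights/` and the `(N, 31, 513, 1)` layout returned by `spec_extraction`:

```python
import numpy as np
import torch
from model import melody_ResNet
from model_torch import Melody_ResNet

keras_model = melody_ResNet()
keras_model.load_weights('./weights/ResNet_NS.hdf5')

torch_model = Melody_ResNet()
torch_model.load_state_dict(torch.load('./weights/torch_weights.pt'))
torch_model.eval()

x = np.random.rand(4, 31, 513, 1).astype(np.float32)
y_keras = keras_model.predict(x)
with torch.no_grad():
    y_torch = torch_model(torch.from_numpy(x).permute(0, 3, 1, 2)).numpy()

# Agreement is approximate, not exact: the LSTM uses hard_sigmoid in Keras
# but a plain sigmoid in PyTorch.
print(np.abs(y_keras - y_torch).max())
```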