forked from Brankonymous/UrbanSoundClassification
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
67 lines (50 loc) · 2.31 KB
/
main.py
File metadata and controls
67 lines (50 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import argparse
import shutil
import time
import numpy as np
import torch
from train import TrainNeuralNetwork
from test import TestNeuralNetwork
from test_custom import CustomTest
import utils.utils as utils
from utils.constants import *
import datetime
def train(config):
    """Run K-fold cross-validation training.

    For every fold number ``val_fold`` in ``1..K_FOLD`` a fresh
    ``TrainNeuralNetwork`` is built from ``config`` and trained via
    ``startTrain(val_fold)``. Elapsed time is printed before each fold and
    once more after the last one.

    Args:
        config: Run configuration dict (the parsed command-line options),
            forwarded unchanged to ``TrainNeuralNetwork``.
    """
    # Use the monotonic clock: unlike datetime.now(), it never jumps on
    # DST changes or system-clock adjustments, so durations are exact.
    start = time.monotonic()

    def elapsed():
        # Wrap in timedelta to keep the "H:MM:SS.ffffff" print format.
        return datetime.timedelta(seconds=time.monotonic() - start)

    for val_fold in range(1, K_FOLD + 1):
        print('Time from REAL beginning: ', elapsed())
        print(f'--------- Validation fold {val_fold} ---------')
        # A new trainer per fold, presumably so every fold starts from
        # fresh model state -- TODO confirm in train.py.
        trainNeuralNet = TrainNeuralNetwork(config=config)
        trainNeuralNet.startTrain(val_fold)
    print('Training lasted for: ', elapsed())
def test(config):
    """Evaluate the model on every cross-validation fold.

    One ``TestNeuralNetwork`` built from ``config`` is run on folds
    ``1..K_FOLD`` via ``startTest``. Its ``flag_show`` argument receives
    ``config['show_results']`` (the ``--show_results`` option).
    ``printAccuracy`` is called once, after the last fold.

    Args:
        config: Run configuration dict (the parsed command-line options).
    """
    evaluator = TestNeuralNetwork(config=config)
    folds = range(1, K_FOLD + 1)
    for fold in folds:
        evaluator.startTest(fold, flag_show=config['show_results'])
    # The same evaluator is reused across folds, presumably so this summary
    # covers all of them -- TODO confirm in test.py.
    evaluator.printAccuracy()
def custom_test(config):
    """Build a ``CustomTest`` from ``config`` and run its ``startTest``.

    The file to classify presumably comes from ``config['custom_test_path']``
    (the ``--custom_test_path`` option) -- TODO confirm in test_custom.py.

    Args:
        config: Run configuration dict (the parsed command-line options).
    """
    # Example input used while developing: data\custom_audio\dog_barking.mp3
    CustomTest(config).startTest()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Common params
# Izbrisali smo argparse.BooleanOptionalAction
parser.add_argument('--type', choices=[m.name for m in ModelType], type=str, help='Input TRAIN, TEST or CUSTOM_TEST for type of classification', default=ModelType.CUSTOM_TEST.name)
parser.add_argument('--model_name', choices=[m.name for m in SupportedModels], type=str, help='Neural network (model) to use', default=SupportedModels.VGG.name)
parser.add_argument('--show_results', help='Plot loss and accuracy info', default=False)
parser.add_argument('--save_results', help='Save loss and accuracy info', default=False)
parser.add_argument('--save_model', help='Save model during training', default=True)
parser.add_argument('--custom_test_path', help= 'Path for custom audio to classify', default='')
# Wrapping configuration into a dictionary
args = parser.parse_args()
config = dict()
for arg in vars(args):
config[arg] = getattr(args, arg)
if config['type'] == 'TRAIN' or config['type'] == 'TRAIN_AND_TEST':
train(config)
if config['type'] == 'TEST' or config['type'] == 'TRAIN_AND_TEST':
test(config)
if config['type'] == 'CUSTOM_TEST':
custom_test(config)