-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbirdsong.py
More file actions
executable file
·86 lines (73 loc) · 2.64 KB
/
birdsong.py
File metadata and controls
executable file
·86 lines (73 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Script for training generative model for birdsong.
"""
import argparse
import os
import model
import utils
# Number of spectrograms drawn from the trained model for --plot-gen.
NUM_GENERATED_SAMPLES = 32
# Number of latent-space points requested for the --plot-gif animation.
GIF_NB_POINTS = 60


def build_parser():
    """Build the command-line parser for the birdsong training script.

    Returns:
        argparse.ArgumentParser: parser for the flags that control
        dataset/model rebuilding, which plots to produce, the number of
        time bins per sample, and the number of training epochs.
    """
    parser = argparse.ArgumentParser(
        description='Code for training birdsong generator model.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # action='store_true' already defaults to False, so no explicit default.
    parser.add_argument('-d', '--rebuild-data',
                        action='store_true',
                        help='If set, rebuilds the dataset.')
    parser.add_argument('-m', '--rebuild-model',
                        action='store_true',
                        help='If set, resets the model weights.')
    parser.add_argument('-r', '--plot-real',
                        action='store_true',
                        help='If set, plot samples of real data.')
    parser.add_argument('-g', '--plot-gen',
                        action='store_true',
                        help='If set, plot samples of generated data.')
    parser.add_argument('-f', '--plot-filters',
                        action='store_true',
                        help='If set, plot the model filters.')
    parser.add_argument('-G', '--plot-gif',
                        action='store_true',
                        help='If set, make a gif of interpolating latent space.')
    parser.add_argument('-t', '--time-length',
                        default=50,
                        type=int,
                        metavar='N',
                        help='Number of bins in the time axis per sample.')
    parser.add_argument('-n', '--nb-epoch',
                        default=10,
                        type=int,
                        metavar='N',
                        help='Number of epochs to train.')
    return parser


def _plot_filters(time_length):
    """Plot the discriminator filters, then the generator filters.

    Args:
        time_length: number of time bins per sample; passed through to
            ``model.get_generator_filters``.
    """
    # vmin/vmax=None presumably lets utils.plot_sample choose the colour
    # range from the data -- NOTE(review): confirm in utils.plot_sample.
    disc_filters = model.get_discriminator_filters()
    utils.plot_sample(disc_filters, 'Discriminator filters',
                      vmin=None,
                      vmax=None,
                      width=4,
                      height=2)
    gen_filters = model.get_generator_filters(time_length)
    utils.plot_sample(gen_filters, 'Generator filters',
                      vmin=None,
                      vmax=None,
                      width=4,
                      height=2)


def main(argv=None):
    """Train the birdsong generator and optionally produce diagnostic plots.

    Steps, in order:
    1. Load the spectrogram dataset (rebuilt if --rebuild-data) and
       optionally plot real samples.
    2. Train the model via ``model.train`` (weights reset if
       --rebuild-model).
    3. Optionally plot generated samples, model filters, and a
       latent-space interpolation gif.

    Args:
        argv: sequence of command-line arguments to parse; ``None`` uses
            ``sys.argv[1:]`` (argparse's default).

    Raises:
        SystemExit: from argparse on invalid arguments or ``--help``.
    """
    args = build_parser().parse_args(argv)

    real_data = utils.get_all_spectrograms(args.time_length,
                                           rebuild=args.rebuild_data)
    if args.plot_real:
        utils.plot_sample(real_data, 'Spectrograms of real data')

    trained_model = model.train(real_data,
                                nb_epoch=args.nb_epoch,
                                rebuild=args.rebuild_model)

    if args.plot_gen:
        generated = trained_model.sample(['normal'],
                                         num_samples=NUM_GENERATED_SAMPLES)
        utils.plot_sample(generated, 'Generated spectrograms')

    if args.plot_filters:
        _plot_filters(args.time_length)

    if args.plot_gif:
        pts = model.interpolate_latent_space(trained_model,
                                             nb_points=GIF_NB_POINTS)
        utils.plot_as_gif(pts)


if __name__ == '__main__':
    main()