train.py
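"""Training entry point for FastSpeech: builds the data pipeline, model,
optimizer, and LR schedule, then hands everything off to train_loop."""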
import torch
from torch.optim.lr_scheduler import OneCycleLR
from wandb_writer import WanDBWriter
from config.configs_classes import TrainConfig, MelSpectrogramConfig, FastSpeechConfig
from collator.collate import get_data_to_buffer, BufferDataset, collate_fn_tensor
from torch.utils.data import DataLoader
from loss.loss import FastSpeechLoss
from model.fastspeech import FastSpeech
import os
from train_loop.train_function import train_loop
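# Instantiate the config objects; all hyperparameters referenced below
# (batch size, epochs, learning rate, device) are expected to live on these.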
mel_config = MelSpectrogramConfig()
model_config = FastSpeechConfig()
train_config = TrainConfig()
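# Load the whole preprocessed dataset into an in-memory buffer, then take a
# positional 80/20 train/validation split. The 256-item slice below is
# presumably a small fixed subset for one-batch overfitting sanity checks;
# here it is only printed, never trained on.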
buffer = get_data_to_buffer(train_config)
dataset = BufferDataset(buffer)
one_batch_dataset = dataset[:256]
print(len(one_batch_dataset))
train_dataset = dataset[:int(len(dataset) * 0.8)]
val_dataset = dataset[int(len(dataset) * 0.8):]
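# Each loader batch holds batch_expand_size * batch_size items; judging by the
# scheduler setup further down, collate_fn_tensor regroups every loader batch
# into batch_expand_size sub-batches of batch_size each.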
training_loader = DataLoader(
    train_dataset,
    batch_size=train_config.batch_expand_size * train_config.batch_size,
    shuffle=True,
    collate_fn=collate_fn_tensor,
    drop_last=True,
    num_workers=0
)
val_loader = DataLoader(
    val_dataset,
    batch_size=train_config.batch_expand_size * train_config.batch_size,
    shuffle=True,
    collate_fn=collate_fn_tensor,
    drop_last=True,
    num_workers=0
)
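# Note: shuffle=True on the validation loader randomizes evaluation order each
# epoch. Aggregate validation metrics are unaffected, but per-batch logs will
# not be comparable across epochs.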
model = FastSpeech(model_config, mel_config)
model = model.to(train_config.device)
fastspeech_loss = FastSpeechLoss()
current_step = 0
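# AdamW with betas=(0.9, 0.98) and eps=1e-9, the Adam settings from the
# original Transformer paper, which FastSpeech also follows.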
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_config.learning_rate,
    betas=(0.9, 0.98),
    eps=1e-9
)
print(len(training_loader), train_config.batch_expand_size)
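# OneCycleLR counts individual optimizer steps. Assuming one step per
# sub-batch, steps_per_epoch is len(training_loader) * batch_expand_size
# rather than just len(training_loader).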
scheduler = OneCycleLR(
    optimizer,
    steps_per_epoch=len(training_loader) * train_config.batch_expand_size,
    epochs=train_config.epochs,
    anneal_strategy="cos",
    max_lr=train_config.learning_rate,
    pct_start=0.1
)
logger = WanDBWriter(train_config)
train_loop(model, train_config, training_loader, val_loader, fastspeech_loss, logger, optimizer, scheduler)
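# Run with `python train.py`; metrics stream to Weights & Biases through
# WanDBWriter, and hyperparameters are edited in config/configs_classes.py.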