-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path30train_LSTM.py
More file actions
32 lines (24 loc) · 853 Bytes
/
30train_LSTM.py
File metadata and controls
32 lines (24 loc) · 853 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import tensorflow as tf
from tools import general
from tools.integration import train_val
if __name__ == '__main__':
    # Build the custom run configuration for the LSTM model and hand it to
    # `train_val` together with the full project config.
    # Model name, optimizer and epoch count are read from the "LSTM" section of
    # the config; identifier and purpose are read from its "public" section.
    opts = general.load_config()
    method = "LSTM"  # key selecting the model section inside `opts`
    distance_int = 3  # NOTE(review): semantics not visible here; see train_val
    # Other datasets listed by the author: 'cairo', 'milan', 'kyoto7', 'kyoto8', 'kyoto11'.
    data_name = 'kyoto7'
    calculation_unit = "0"  # presumably a device id (e.g. GPU index) -- TODO confirm

    # Parameter added during the paper revision.
    data_max_length = 300  # data length; the right value still needs to be explored
    # NOTE(review): data_max_length is not included in dict_config_cus below, so it
    # currently has no effect on training -- confirm whether train_val should get it.

    dict_config_cus = {
        'model_name': opts[method]["model_name"],
        'optimizer': opts[method]["optimizer"],
        'distance_int': distance_int,
        'data_name': data_name,
        'calculation_unit': calculation_unit,
        'epochs': opts[method]["epochs"],
        'identifier': opts["public"]["identifier"],
        'purpose': opts["public"]["purpose"],
    }
    train_val(dict_config_cus, opts)