-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain-svm.py
More file actions
39 lines (31 loc) · 838 Bytes
/
train-svm.py
File metadata and controls
39 lines (31 loc) · 838 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import sys
import dataset_mapper
import classifier_mapper
import svm
# Optional numeric hyper-parameters accepted after <kernel>, as
# (argv index, converter, default). The defaults are the values used when the
# argument is omitted; all four are forwarded positionally to svm.SVM after
# the kernel name.
_OPTIONAL_PARAMS = (
    (4, float, 1.0),  # gamma
    (5, int, 2),      # degree
    (6, float, 0),    # c
    (7, float, 5),    # C
)


def _print_usage(prog):
    """Write the command-line synopsis to stderr."""
    sys.stderr.write(
        "Usage: " + prog
        + " <dataset> <output> <kernel> [<gamma> <degree> <c> <C>]\n"
    )


def _parse_args(argv):
    """Split ``argv`` into the training configuration.

    Returns a tuple ``(dataset_path, output_path, kernel, gamma, degree, c, C)``.
    Optional values that are omitted take their defaults from
    ``_OPTIONAL_PARAMS``.

    Raises ValueError if ``<dataset>``, ``<output>`` or ``<kernel>`` is
    missing, or if an optional value cannot be converted to its numeric type.
    """
    # argv[0] is the program name, so all three required arguments are
    # present only when len(argv) >= 4.
    if len(argv) < 4:
        raise ValueError("expected <dataset> <output> <kernel>")
    dataset_path, output_path, kernel = argv[1], argv[2], argv[3]
    numeric = tuple(
        convert(argv[index]) if len(argv) > index else default
        for index, convert, default in _OPTIONAL_PARAMS
    )
    return (dataset_path, output_path, kernel) + numeric


def main(argv=None):
    """Train an SVM on ``<dataset>`` and write the resulting model to ``<output>``.

    argv: full argument vector including the program name; defaults to
        ``sys.argv``.

    On missing or malformed arguments, prints an error and the usage line to
    stderr and exits with status 1.
    """
    if argv is None:
        argv = sys.argv
    try:
        dataset_path, output_path, kernel, gamma, degree, c, C = _parse_args(argv)
    except ValueError as err:
        prog = argv[0] if argv else "train-svm.py"
        sys.stderr.write("Error: " + str(err) + "\n")
        _print_usage(prog)
        # Positive status: sys.exit(-1) is reported as 255 on POSIX shells.
        sys.exit(1)

    # Read (x, y) from the dataset file -- presumably features and labels;
    # confirm against DatasetMapper.read.
    reader = dataset_mapper.DatasetMapper()
    x, y = reader.read(dataset_path)

    # Train with the requested kernel and hyper-parameters.
    trainer = svm.SVM(kernel, gamma, degree, c, C)
    model = trainer.train(x, y)

    # Persist the trained model. Distinct local names avoid rebinding the
    # dataset_mapper / classifier_mapper module names.
    writer = classifier_mapper.ClassifierMapper()
    writer.create(model, output_path)


# The file name contains a hyphen, so it is run as a script rather than
# imported by name; the guard keeps loading it (e.g. via runpy/importlib)
# free of side effects.
if __name__ == "__main__":
    main()