-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhyperparameter_tuning.py
More file actions
154 lines (137 loc) · 5.36 KB
/
hyperparameter_tuning.py
File metadata and controls
154 lines (137 loc) · 5.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python
"""
Hyperparameter tuning script. Uses Optuna to search for good hyperparameters
for the ant colony optimization and Q-learning TSP solvers. It can also be used
to run a single problem through n trials with different hyperparameters to find
the best achievable result.
Adapted from
https://github.com/AGnias47/optical-hypertension-detection/blob/main/optuna_study.py
"""
import argparse
import optuna
from optuna.integration.mlflow import MLflowCallback
from src.algorithms.aco.ant_system import AntSystem
from src.algorithms.aco.max_min_ant_system import MaxMinAntSystem
from src.algorithms.q_learning.deep_q_learning import DeepQLearning
from src.algorithms.q_learning.double_q_learning import DoubleQLearning
from src.algorithms.q_learning.q_learning import QLearning
from src.utils.arg_parsing import get_filepath_for_problem
# Number of Optuna trials to run when neither --trials nor --timeout is given.
DEFAULT_TRIALS: int = 10
# Worker count passed to ``Study.optimize(n_jobs=...)``.
N_JOBS: int = 1
class Objective:
    """Base Optuna objective bound to one solver class and one problem.

    Subclasses implement ``__call__(trial)`` and use the attributes set here.

    :param algorithm: Solver class that subclasses instantiate for each trial.
    :param problem: Problem name; resolved to a file path once, at construction.
    """

    def __init__(self, algorithm, problem):
        self.algorithm, self.problem = algorithm, problem
        # Resolve the path once so every trial reuses the same file.
        self.filepath = get_filepath_for_problem(self.problem)
class ACOObjective(Objective):
    """Optuna objective that tunes an ant colony optimization solver.

    :param algorithm: Solver class constructed with ``filepath``, ``alpha``,
        ``beta``, ``iterations`` and ``rho`` keyword arguments (plus
        ``stagnation_tolerance`` when ``mmas`` is true).
    :param problem: Problem name passed to ``get_filepath_for_problem``.
    :param mmas: When true, use the Max-Min Ant System search space. It samples
        a smaller ``rho`` range and also tunes ``stagnation_tolerance``.
    """

    def __init__(self, algorithm, problem, mmas=False):
        super().__init__(algorithm, problem)
        self.mmas = mmas

    def __call__(self, trial: optuna.trial.BaseTrial):
        """Sample one configuration, run the solver, and return its cost.

        :param trial: Optuna trial used to suggest hyperparameter values.
        :return: The cost returned by the solver's ``run_tsp()``.
        """
        trial.set_user_attr("problem", self.problem)
        # Suggestions are made in a fixed order: alpha, beta, iterations, rho,
        # then stagnation_tolerance.
        params = {
            "alpha": trial.suggest_int("alpha", 1, 2),
            "beta": trial.suggest_int("beta", 2, 5),
            "iterations": trial.suggest_int("iterations", 100, 2500),
        }
        if not self.mmas:
            params["rho"] = trial.suggest_float("rho", 0.3, 0.7)
        else:
            params["rho"] = trial.suggest_float("rho", 0.01, 0.2)
            params["stagnation_tolerance"] = trial.suggest_int(
                "stagnation_tolerance", 20, 350
            )
        solver = self.algorithm(filepath=self.filepath, **params)
        (cost, _route), _total_time = solver.run_tsp()
        return cost
class QLearningObjective(Objective):
    """Optuna objective that tunes a Q-learning-family solver.

    The constructor is inherited from ``Objective`` and takes
    ``(algorithm, problem)``. ``algorithm`` is constructed with ``filepath``,
    ``alpha``, ``gamma``, ``epsilon_func_key`` and ``reward_func_key`` keyword
    arguments.
    """

    def __call__(self, trial: optuna.trial.BaseTrial):
        """Sample one configuration, run the solver, and return its cost.

        :param trial: Optuna trial used to suggest hyperparameter values.
        :return: The cost returned by the solver's ``run_tsp()``.
        """
        trial.set_user_attr("problem", self.problem)
        # Keyword arguments are evaluated left to right, so the suggestion
        # order stays alpha, gamma, epsilon, reward.
        hyperparameters = dict(
            alpha=trial.suggest_float("alpha", 0.0001, 0.1),
            gamma=trial.suggest_float("gamma", 0.0001, 0.99999),
            epsilon_func_key=trial.suggest_categorical(
                "epsilon", ["e1", "e2", "e3", "e4"]
            ),
            reward_func_key=trial.suggest_categorical("reward", ["r1", "r2", "r3"]),
        )
        solver = self.algorithm(filepath=self.filepath, **hyperparameters)
        (cost, _route), _total_time = solver.run_tsp()
        return cost
def _positive_int(value):
    """Parse a command-line string as a strictly positive integer.

    Used as the argparse ``type=`` for ``--trials`` and ``--timeout``. Without
    it, ``0`` was silently replaced by the default trial count and negative
    values were passed through to Optuna unchecked.

    :param value: Raw string from the command line.
    :return: The parsed integer, guaranteed to be > 0.
    :raises argparse.ArgumentTypeError: If ``value`` is not an integer > 0.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    :param argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    :return: Namespace with ``algorithm``, ``problem``, ``trials`` and
        ``timeout``; at most one of the last two is set, and either may be None.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-a", "--algorithm", required=True, choices=["as", "mmas", "q", "dq", "dqn"]
    )
    parser.add_argument("-p", "--problem", required=True)
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "-n",
        "--trials",
        type=_positive_int,
        help="Number of trials to run. Cannot be specified with timeout. "
        f"If neither are specified, {DEFAULT_TRIALS} trials run.",
    )
    group.add_argument(
        "-t",
        "--timeout",
        type=_positive_int,
        help="Time in seconds to run before stopping. Cannot be specified with "
        f"trials. If neither are specified, {DEFAULT_TRIALS} trials run.",
    )
    return parser.parse_args(argv)


def _optimize_kwargs(args):
    """Translate the trials/timeout options into ``Study.optimize`` kwargs.

    :param args: Parsed namespace with ``trials`` and ``timeout`` attributes.
    :return: ``{"n_trials": ...}`` or ``{"timeout": ...}``.
    """
    # Compare against None, not truthiness, so an explicitly supplied value is
    # never mistaken for "not given".
    if args.trials is not None:
        return {"n_trials": args.trials}
    if args.timeout is not None:
        return {"timeout": args.timeout}
    return {"n_trials": DEFAULT_TRIALS}


def _build_study(algorithm, problem):
    """Create the Optuna study and objective for an algorithm key.

    :param algorithm: One of ``"as"``, ``"mmas"``, ``"q"``, ``"dq"``, ``"dqn"``.
    :param problem: Problem name forwarded to the objective.
    :return: ``(study, objective)`` tuple.
    :raises ValueError: If ``algorithm`` is not a known key.
    """
    # key -> (banner, study name, objective class, solver class, extra kwargs)
    registry = {
        "as": (
            "Running Ant System Study",
            "Ant System Hyperparameter Tuning",
            ACOObjective,
            AntSystem,
            {},
        ),
        "mmas": (
            "Running Max-Min Ant System Study",
            "Max-Min Ant System Hyperparameter Tuning",
            ACOObjective,
            MaxMinAntSystem,
            {"mmas": True},
        ),
        "q": (
            "Running Q-Learning Study",
            "Q-Learning Hyperparameter Tuning",
            QLearningObjective,
            QLearning,
            {},
        ),
        "dq": (
            "Running Double Q-Learning Study",
            "Double Q-Learning Hyperparameter Tuning",
            QLearningObjective,
            DoubleQLearning,
            {},
        ),
        "dqn": (
            "Running Deep Q-Learning Study",
            "Deep Q-Learning Hyperparameter Tuning",
            QLearningObjective,
            DeepQLearning,
            {},
        ),
    }
    try:
        banner, study_name, objective_cls, solver_cls, extra = registry[algorithm]
    except KeyError:
        raise ValueError("Invalid algorithm specified") from None
    print(banner)
    # create_study's default direction is minimize, which suits a cost value.
    study = optuna.create_study(study_name=study_name)
    objective = objective_cls(solver_cls, problem, **extra)
    return study, objective


def main(argv=None):
    """Run a tuning study from the command line and report the best trial.

    :param argv: Argument list; ``None`` means ``sys.argv[1:]``.
    :return: Process exit status: 0 on success, 1 if no trial completed.
    """
    args = _parse_args(argv)
    optimize_kwargs = _optimize_kwargs(args)
    study, objective = _build_study(args.algorithm, args.problem)
    try:
        study.optimize(
            func=objective,
            n_jobs=N_JOBS,
            callbacks=[MLflowCallback(metric_name="cost")],
            **optimize_kwargs,
        )
    except KeyboardInterrupt:
        # Ctrl-C stops the search early; still report the best result so far.
        print("Ending study")
    print("Optuna study best trial:")
    try:
        trial = study.best_trial
    except ValueError:
        # Optuna raises ValueError when no trial has completed, e.g. the run
        # was interrupted before the first trial finished.
        print("No completed trials.")
        return 1
    print(f"Cost: {trial.value}")
    print("Params: ")
    for key, value in trial.params.items():
        print(f"{key}: {value}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())