forked from AgentTorch/AgentTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample.py
More file actions
89 lines (65 loc) · 2.44 KB
/
example.py
File metadata and controls
89 lines (65 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from agent_torch.core.executor import Executor
from agent_torch.core.dataloader import LoadPopulation
from agent_torch.models import covid
from agent_torch.populations import astoria
from custom_population import customize
import operator
from functools import reduce
import torch
def set_params(runner, input_string, new_value):
    """Replace the calibratable tensor addressed by ``input_string`` with ``new_value``.

    Thin wrapper around ``map_and_replace_tensor``: builds the getter/setter for
    the dotted path and immediately invokes it in "set" mode.

    Args:
        runner: Simulation runner whose substep attribute is overwritten.
        input_string: Dotted parameter path, e.g.
            ``"initializer.transition_function.0.new_transmission.learnable_args.R2"``.
        new_value: Replacement tensor.

    Returns:
        The tensor stored on the target substep after the update. The original
        implementation computed this value but discarded it and returned None;
        returning it is backward-compatible for callers that ignored the result.
    """
    tensor_func = map_and_replace_tensor(input_string)
    return tensor_func(runner, new_value)
def map_and_replace_tensor(input_string):
    """Build a getter/setter for a calibratable tensor addressed by a dotted path.

    ``input_string`` is split on ``"."`` and must have at least six components,
    e.g. ``"initializer.transition_function.0.new_transmission.learnable_args.R2"``.
    Only components 1, 2, 3 and 5 are used. The returned callable looks up
    ``getattr(runner.initializer, parts[1])[parts[2]]``, takes attribute
    ``parts[3]`` of that, and reads or writes its ``"calibrate_" + parts[5]``
    attribute. Components 0 and 4 are not consulted; the lookup always starts
    at ``runner.initializer``.

    Args:
        input_string: Dotted path naming the parameter.

    Returns:
        A callable ``getter_and_setter(runner, new_value=None)``. It prints the
        current tensor, then:

        * with ``new_value=None``, returns the current tensor;
        * otherwise, stores ``new_value`` in its place and returns the newly
          stored tensor.

        The callable raises ``ValueError`` if ``new_value.requires_grad`` differs
        from that of the tensor it replaces.

    Raises:
        ValueError: If ``input_string`` has fewer than six dot-separated
            components (raised immediately, not when the callable is used).
    """
    parts = input_string.split(".")
    if len(parts) < 6:
        raise ValueError(
            "expected a dotted parameter path with at least 6 components, "
            f"got {input_string!r}"
        )
    # parts[0] (e.g. "initializer") and parts[4] (e.g. "learnable_args") are
    # intentionally unused; the lookup below is hard-wired to runner.initializer.
    function, index, sub_func, var_name = parts[1], parts[2], parts[3], parts[5]
    attr_name = "calibrate_" + var_name

    def getter_and_setter(runner, new_value=None):
        substep_type = getattr(runner.initializer, function)
        # `index` is already a str (it came from str.split); substeps are keyed by string.
        substep_function = getattr(substep_type[index], sub_func)
        current_tensor = getattr(substep_function, attr_name)
        print("Current value: ", current_tensor)
        if new_value is None:
            return current_tensor
        # Explicit check instead of `assert` so the guard survives `python -O`.
        if new_value.requires_grad != current_tensor.requires_grad:
            raise ValueError(
                f"requires_grad mismatch for {attr_name!r}: new value has "
                f"{new_value.requires_grad}, current value has "
                f"{current_tensor.requires_grad}"
            )
        setattr(substep_function, attr_name, new_value)
        # Re-read so the caller sees whatever the attribute now holds.
        return getattr(substep_function, attr_name)

    return getter_and_setter
def setup(model, population):
    """Create an Executor for ``model`` over ``population`` and return its initialised runner.

    The population is wrapped in a ``LoadPopulation`` loader before being handed
    to the ``Executor``; the executor's runner has ``init()`` called on it before
    it is returned.
    """
    executor = Executor(model=model, pop_loader=LoadPopulation(population))
    sim_runner = executor.runner
    sim_runner.init()
    return sim_runner
def simulate(runner):
    """Run one episode on ``runner`` and return a scalar loss.

    The episode length is read from
    ``runner.config["simulation_metadata"]["num_steps_per_episode"]``. After
    stepping, the loss is the sum of the ``"daily_infected"`` entry in the
    ``"environment"`` section of ``runner.state_trajectory[-1][-1]``, i.e. the
    last recorded state of the last trajectory.
    """
    metadata = runner.config["simulation_metadata"]
    runner.step(metadata["num_steps_per_episode"])
    final_state = runner.state_trajectory[-1][-1]
    return final_state["environment"]["daily_infected"].sum()
# Example driver: build the COVID model on the Astoria population, then
# overwrite one calibratable parameter with a new tensor.
runner = setup(covid, astoria)

# (name, parameter) pairs exposed by the runner, kept for inspection.
learn_params = list(runner.named_parameters())

new_tensor = torch.tensor([3.5, 4.2, 5.6], requires_grad=True)
# Dotted path of the parameter to replace. The previous code first assigned
# learn_params[0][0] here and then immediately overwrote it; that dead
# assignment crashed with IndexError when the runner had no named parameters.
input_string = "initializer.transition_function.0.new_transmission.learnable_args.R2"
params_dict = {input_string: new_tensor}
# NOTE(review): _set_parameters is a private Runner method; `set_params` in
# this file is an alternative that goes through the dotted-path getter/setter.
runner._set_parameters(params_dict)

# TODO:
#   1. Custom population size
#   2. Init Infections
#   3. Set parameters
#   4. Visualize values