-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimulator_utils.py
More file actions
101 lines (89 loc) · 3.09 KB
/
simulator_utils.py
File metadata and controls
101 lines (89 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from abc import ABC, abstractmethod
import torch
from tqdm import tqdm
from common import ConditionalVectorField
class ODE(ABC):
    """Abstract interface for an ODE described by its drift coefficient."""

    @abstractmethod
    def drift_coefficient(
        self, xt: torch.Tensor, t: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """
        Evaluate the drift coefficient of the ODE at state xt and time t.

        Args:
            - xt: state at time t, shape (bs, c, h, w)
            - t: time, documented here as shape (bs, 1)
              NOTE(review): the simulators in this file pass t with shape
              (bs, 1, 1, 1) -- confirm which shape implementations expect.
            - **kwargs: any extra inputs a concrete ODE needs
        Returns:
            - drift_coefficient: shape (bs, c, h, w)
        """
class Simulator(ABC):
    """Base class for integrators that walk a state along a fixed time grid.

    Subclasses implement `step`; `simulate` and `simulate_with_trajectory`
    apply it repeatedly over consecutive pairs of timesteps.
    """

    @abstractmethod
    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs):
        """
        Advance the state by one simulation step.

        Args:
            - xt: state at time t, shape (bs, c, h, w)
            - t: current time, shape (bs, 1, 1, 1)
            - dt: step size, shape (bs, 1, 1, 1)
            - **kwargs: extra inputs forwarded by the simulate methods
        Returns:
            - nxt: state at time t + dt, shape (bs, c, h, w)
        """

    @torch.no_grad()
    def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
        """
        Integrate x along the discretization given by ts (gradients disabled).

        Args:
            - x: initial state, shape (bs, c, h, w)
            - ts: timesteps, shape (bs, nts, 1, 1, 1)
            - **kwargs: forwarded unchanged to every `step` call
        Returns:
            - x_final: state after the last step, shape (bs, c, h, w)
        """
        num_intervals = ts.size(1) - 1
        for k in tqdm(range(num_intervals)):
            t_now = ts[:, k]
            t_next = ts[:, k + 1]
            # The step size is the gap between consecutive grid points.
            x = self.step(x, t_now, t_next - t_now, **kwargs)
        return x

    @torch.no_grad()
    def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
        """
        Integrate x along ts and keep every intermediate state.

        Args:
            - x: initial state, shape (bs, c, h, w)
            - ts: timesteps, shape (bs, nts, 1, 1, 1)
            - **kwargs: forwarded unchanged to every `step` call
        Returns:
            - xs: the initial state followed by the state after each step,
              stacked along dim 1, shape (batch_size, nts, c, h, w)
        """
        # Clones keep each recorded snapshot independent of the tensor
        # handed to / returned from `step`.
        trajectory = [x.clone()]
        num_intervals = ts.size(1) - 1
        for k in tqdm(range(num_intervals)):
            t_now = ts[:, k]
            t_next = ts[:, k + 1]
            x = self.step(x, t_now, t_next - t_now, **kwargs)
            trajectory.append(x.clone())
        return torch.stack(trajectory, dim=1)
class EulerSimulator(Simulator):
    """Simulator whose step is the explicit Euler update for a given ODE."""

    def __init__(self, ode: ODE):
        # ODE whose drift_coefficient is evaluated on every step.
        self.ode = ode

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor, **kwargs):
        """
        Take one Euler step: xt + drift(xt, t) * h.

        Args:
            - xt: current state
            - t: current time, passed straight to the ODE
            - h: step size multiplying the drift
            - **kwargs: forwarded to `ode.drift_coefficient`
        Returns:
            - the state after the step
        """
        drift = self.ode.drift_coefficient(xt, t, **kwargs)
        return xt + drift * h
class CFGVectorFieldODE(ODE):
    """ODE whose drift blends a label-conditioned and an "unguided" network output.

    The drift is
        (1 - guidance_scale) * net(x, t, null) + guidance_scale * net(x, t, y)
    where `null` is a label tensor filled with `null_label`.
    """

    def __init__(
        self,
        net: ConditionalVectorField,
        guidance_scale: float = 1.0,
        null_label: int = 10,
    ):
        """
        Args:
            - net: callable invoked as net(x, t, y) to produce a vector field
            - guidance_scale: weight on the label-conditioned output; the
              unguided output receives weight (1 - guidance_scale)
            - null_label: label value used for the unguided evaluation.
              Defaults to 10, the value previously hard-coded here.
              NOTE(review): presumably the "no class" label the network was
              trained with -- confirm against the training code.
        """
        self.net = net
        self.guidance_scale = guidance_scale
        self.null_label = null_label

    def drift_coefficient(
        self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor
    ) -> torch.Tensor:
        """
        Evaluate the guided drift at state x, time t, for labels y.

        Args:
            - x: (bs, c, h, w)
            - t: (bs, 1, 1, 1)
            - y: (bs,)
        Returns:
            - the weighted combination of the unguided and guided vector
              fields described on the class
        """
        guided_vector_field = self.net(x, t, y)
        # Same shape/dtype/device as y, but every entry is the null label.
        unguided_y = torch.full_like(y, self.null_label)
        unguided_vector_field = self.net(x, t, unguided_y)
        unguided_weight = 1 - self.guidance_scale
        return (
            unguided_weight * unguided_vector_field
            + self.guidance_scale * guided_vector_field
        )