Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions imitation/config/eval.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
defaults:
- _self_
- task: lift_graph
- policy: graph_diffusion_policy
- policy: mog_graph_diffusion_policy

render: False
render: True
output_video: True

num_episodes: 20
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
_target_: imitation.policy.ar_graph_diffusion_policy.AutoregressiveGraphDiffusionPolicy
_target_: imitation.policy.mog_ar_graph_diffusion_policy.StochAutoregressiveGraphDiffusionPolicy
dataset: ${task.dataset}
node_feature_dim: 9 # 9 when task-joint-space, 8 otherwise
num_edge_types: 2 # robot joints, object-robot
lr: 0.0005
ckpt_path: ./weights/${task.task_name}_${task.dataset_type}_film_graph_diffusion_policy.pt
ckpt_path: ./weights/${task.task_name}_${task.dataset_type}_mog_graph_diffusion_policy.pt
device: cuda
denoising_network:
_target_: imitation.model.graph_diffusion.FiLMConditionalGraphDenoisingNetwork
_target_: imitation.model.graph_diffusion.MoGConditionalGraphDenoisingNetwork
node_feature_dim: ${policy.node_feature_dim}
obs_horizon: ${obs_horizon}
pred_horizon: ${pred_horizon}
edge_feature_dim: 1
num_edge_types: ${policy.num_edge_types}
num_layers: 5
hidden_dim: 256
num_mixtures: 1
2 changes: 1 addition & 1 deletion imitation/config/train.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defaults:
- _self_
- task: lift_graph
- policy: graph_diffusion_policy
- policy: mog_graph_diffusion_policy

output_dir: ./outputs
# on evaluating, for environment wrapper
Expand Down
62 changes: 44 additions & 18 deletions imitation/model/graph_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def forward(self, x, edge_index, edge_attr, v_t=None, h_v=None):
return node_pred, p_e, h_v


class FiLMConditionalGraphDenoisingNetwork(nn.Module):
class MoGConditionalGraphDenoisingNetwork(nn.Module):
def __init__(self,
node_feature_dim,
obs_horizon,
Expand All @@ -169,6 +169,7 @@ def __init__(self,
num_edge_types,
num_layers=5,
hidden_dim=256,
num_mixtures=1,
device=None):
'''
Denoising GNN (based on GraphARM) with FiLM conditioning on
Expand All @@ -185,11 +186,17 @@ def __init__(self,
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
self.hidden_dim = hidden_dim
self.node_embedding = Linear(node_feature_dim*pred_horizon, hidden_dim).to(self.device)
self.edge_embedding = Linear(edge_feature_dim, hidden_dim).to(self.device)
assert num_mixtures & (num_mixtures - 1) == 0, "num_mixtures should be a power of 2"
self.num_mixtures = num_mixtures

# Node embedding
self.node_embedding = Linear(node_feature_dim*pred_horizon, self.hidden_dim).to(self.device)
# Edge embedding
self.edge_embedding = Linear(edge_feature_dim, self.hidden_dim).to(self.device)

# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
self.cond_channels = hidden_dim * 2
self.cond_channels = self.hidden_dim * 2
self.cond_encoders = nn.ModuleList()
for _ in range(num_layers):
self.cond_encoders.append(nn.Sequential(
Expand All @@ -198,20 +205,23 @@ def __init__(self,
nn.Unflatten(-1, (-1, 1))
).to(self.device))



# Graph layers
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(MPLayer(hidden_dim, hidden_dim)).to(self.device)

self.node_pred_layer = nn.Sequential(Linear(2 * hidden_dim, hidden_dim),
self.layers.append(MPLayer(self.hidden_dim, self.hidden_dim).to(self.device))

# Prediction layer (adjusted for multiple mixtures)
self.node_pred_layer = nn.Sequential(
Linear(2 * self.hidden_dim, self.hidden_dim),
nn.ReLU(),
Linear(hidden_dim, hidden_dim),
Linear(self.hidden_dim, self.hidden_dim),
nn.ReLU(),
Linear(hidden_dim, self.node_feature_dim*self.pred_horizon)
Linear(self.hidden_dim, 2 * self.node_feature_dim * self.pred_horizon * self.num_mixtures + self.num_mixtures) # 2 * hidden_dim for means and variances, and num_mixtures for mixing weights
).to(self.device)

# Positional encoding
self.pe = self.positionalencoding(100).to(self.device) # max length of 100


def positionalencoding(self, lengths):
'''
Expand Down Expand Up @@ -264,14 +274,30 @@ def forward(self, x, edge_index, edge_attr, cond=None, node_order=None):
# repeat graph embedding to have the same shape as h_v
graph_embedding = graph_embedding.repeat(h_v.shape[0], 1)

node_pred = self.node_pred_layer(torch.cat([graph_embedding, h_v], dim=1)) # hidden_dim + 1
# Modified output layer to predict distribution parameters
dist_params = self.node_pred_layer(torch.cat([graph_embedding, h_v], dim=1))

# Split the output into means, variances, and mixing weights
means = dist_params[:, :self.node_feature_dim * self.pred_horizon * self.num_mixtures]
variances = dist_params[:, self.node_feature_dim * self.pred_horizon * self.num_mixtures:2 * self.node_feature_dim * self.pred_horizon * self.num_mixtures]
mixing_weights = dist_params[:, 2 * self.node_feature_dim * self.pred_horizon * self.num_mixtures:]

# softmax the mixing weights and variances (to ensure positivity)
mixing_weights = F.softmax(mixing_weights, dim=-1)
variances = F.softplus(variances)

# Reshape the output to match the number of mixtures
means = means.view(-1, self.pred_horizon, self.node_feature_dim, self.num_mixtures)
variances = variances.view(-1, self.pred_horizon, self.node_feature_dim, self.num_mixtures)
mixing_weights = mixing_weights.view(-1, self.num_mixtures)

# get only last node
v_t = h_v.shape[0] - 1 # node being masked, this assumes that the masked node is the last node in the graph

node_pred = node_pred[v_t]
node_pred = node_pred.reshape(self.pred_horizon, self.node_feature_dim) # reshape to original shape
return node_pred
means = means[v_t]
variances = variances[v_t]
mixing_weights = mixing_weights[v_t]

return means, variances, mixing_weights



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from imitation.utils.graph_diffusion import NodeMasker


class AutoregressiveGraphDiffusionPolicy(nn.Module):
class StochAutoregressiveGraphDiffusionPolicy(nn.Module):
def __init__(self,
dataset,
node_feature_dim,
Expand All @@ -20,7 +20,7 @@ def __init__(self,
ckpt_path=None,
device = None,
mode = 'joint-space'):
super(AutoregressiveGraphDiffusionPolicy, self).__init__()
super(StochAutoregressiveGraphDiffusionPolicy, self).__init__()
if device == None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
Expand All @@ -35,7 +35,7 @@ def __init__(self,
self.global_epoch = 0

self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience=50, factor=0.5, verbose=True, min_lr=lr/100)
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience=200, factor=0.6, verbose=True, min_lr=lr/10)
self.mode = mode

if ckpt_path is not None:
Expand Down Expand Up @@ -100,22 +100,87 @@ def node_decay_ordering(self, graph):
'''
return torch.arange(graph.x.shape[0]-1, -1, -1)

def vlb(self, G_0, edge_type_probs, node, node_order, t):
T = len(node_order)
def nll_loss(self, G_0, dist_params, node):
n_i = G_0.x.shape[0]
# retrieve edge type from G_t.edge_attr, edges between node and node_order[t:]
edge_attrs_matrix = G_0.edge_attr.reshape(T, T)
original_edge_types = torch.index_select(edge_attrs_matrix[node], 0, torch.tensor(node_order[t:]).to(self.device))
# calculate probability of edge type
p_edges = torch.gather(edge_type_probs, 1, original_edge_types.reshape(-1, 1))
log_p_edges = torch.sum(torch.log(p_edges))
# log_p_edges = torch.sum(torch.tensor([0]))
wandb.log({"target_edges_log_prob": log_p_edges})
# calculate loss
log_p_O_v = log_p_edges
loss = -(n_i/T)*log_p_O_v # cumulative, to join (k) from all previously denoised nodes
# get likelihood of joint values
log_likelihood = self.get_distribution_likelihood(dist_params, G_0.x[node,:,:])
loss = - log_likelihood
return loss


def get_distribution_likelihood(self, dist_params, joint_values, fixed_variance=0.01):
    '''
    Return the log-likelihood of joint values under a Gaussian Mixture Model (GMM).

    Args:
        dist_params: sequence (means, variances, mixing_weights) where
            means and variances are shaped [pred_horizon, node_feature_dim, num_mixtures]
            and mixing_weights is shaped [num_mixtures].
        joint_values: tensor shaped [pred_horizon, node_feature_dim].
        fixed_variance: constant variance used for every mixture component.
            NOTE(review): the network-predicted variances (dist_params[1]) are
            deliberately ignored in favor of this constant (the "fixed
            variance" experiment); lowering to the predicted values can be
            done by threading dist_params[1] through here once confirmed.

    Returns:
        A scalar tensor: the sum over all (horizon, feature) elements of the
        log of the mixture-weighted Gaussian density.
    '''
    # Work in double precision on the policy device for numerical stability.
    means = dist_params[0].to(self.device).double()
    # Fixed, shared variance instead of the predicted dist_params[1].
    variances = torch.full_like(means, fixed_variance)
    mixing_weights = dist_params[2].to(self.device).double()

    # Reshape mixing weights to [1, 1, num_mixtures] so they broadcast
    # over the pred_horizon and node_feature_dim axes.
    mixing_weights = mixing_weights.view(1, 1, -1)

    # Squared differences normalized by the (diagonal) variances; the
    # trailing unsqueeze broadcasts joint_values across mixture components.
    diffs = joint_values.to(self.device).unsqueeze(-1).double() - means
    squared_diffs = diffs ** 2 / variances

    # Per-component Gaussian densities, weighted by the mixing weights and
    # summed over the mixture axis to get the mixture density per element.
    densities = torch.exp(-0.5 * squared_diffs) / torch.sqrt(2 * np.pi * variances)
    mixture_density = torch.sum(densities * mixing_weights, dim=2)

    # Clamp to avoid log(0) when the density underflows.
    mixture_density = torch.clamp(mixture_density, min=1e-20)

    # Total log-likelihood across the prediction horizon and features.
    return torch.sum(torch.log(mixture_density))


def sample_from_distribution(self, dist_params):
    '''
    Sample joint values from a Mixture of Gaussians (MoG) distribution.

    The mixture component with the largest mixing weight is selected
    deterministically (mode-seeking, not ancestral sampling over the
    categorical mixing distribution), then one Gaussian sample is drawn
    per (horizon step, feature) from that component.

    Args:
        dist_params: sequence (means, variances, mixing_weights) where
            means and variances are shaped [pred_horizon, node_feature_dim, num_mixtures]
            and mixing_weights is shaped [num_mixtures].

    Returns:
        joint_values: tensor of sampled joint values shaped
            [pred_horizon, node_feature_dim], on the same device and dtype
            as the input means.
    '''
    means = dist_params[0]
    variances = dist_params[1]
    mixing_weights = dist_params[2]

    # Pick the dominant mixture component.
    max_mixture_idx = torch.argmax(mixing_weights)

    # Vectorized draw for every (horizon, feature) element at once. This
    # replaces the previous element-wise Python loop, which also allocated
    # the output on CPU and therefore broke when means/variances lived on
    # a CUDA device; torch.normal keeps the result on the inputs' device.
    selected_means = means[:, :, max_mixture_idx]
    selected_stds = torch.sqrt(variances[:, :, max_mixture_idx])
    return torch.normal(selected_means, selected_stds)

def train(self, dataset, num_epochs=100, model_path=None, seed=0):
'''
Train noise prediction model
Expand All @@ -126,7 +191,7 @@ def train(self, dataset, num_epochs=100, model_path=None, seed=0):
pass
self.optimizer.zero_grad()
self.model.train()
batch_size = 5
batch_size = 1

with tqdm(range(num_epochs), desc='Epoch', leave=False) as tepoch:
for epoch in tepoch:
Expand All @@ -148,21 +213,20 @@ def train(self, dataset, num_epochs=100, model_path=None, seed=0):
G_pred = diffusion_trajectory[t+1].clone().to(self.device)

# predict node and edge type distributions
joint_values = self.model(G_pred.x.float(), G_pred.edge_index, G_pred.edge_attr.float(), cond=G_0.y.float())
logits = self.model(G_pred.x.float(), G_pred.edge_index, G_pred.edge_attr.float(), cond=G_0.y.float())

# joint_values = node_features[0] # first node feature is joint values
# mse loss for node features
loss = self.node_feature_loss(joint_values, G_0.x[node_order[t],:,:].float())
# loss = self.node_feature_loss(joint_values, G_0.x[node_order[t],:,:].float())
loss = self.nll_loss(G_0, logits, node_order[t])

# TODO add loss for absolute positions, to make the model physics-informed

# use correlation as loss
# x = torch.stack([node_features.squeeze(), G_0.x[node_order[t],:,0].float()])
# x += torch.rand_like(x) * 1e-8 # add noise to avoid NaNs
# loss -= (1/batch_size)*torch.corrcoef(x)[0,1]
wandb.log({"epoch": self.global_epoch, "loss": loss.item()})
wandb.log({"epoch": self.global_epoch, "nll_loss": loss.item()})

acc_loss += loss.item()
# backprop (accumulated gradients)
loss.backward(retain_graph=True)
loss.backward()
batch_i += 1
# update weights
if batch_i % batch_size == 0:
Expand Down Expand Up @@ -226,7 +290,8 @@ def get_action(self, obs):
for x_i in range(obs[0].x.shape[0]): # number of nodes in action graph TODO remove objects
action = self.preprocess(action)
# predict node attributes for last node in action
action.x[-1] = self.model(action.x.float(), action.edge_index, action.edge_attr, cond=node_features)
logits = self.model(action.x.float(), action.edge_index, action.edge_attr, cond=node_features)
action.x[-1] = self.sample_from_distribution(logits)
# map edge attributes from obs to action
action.edge_attr = self._lookup_edge_attr(edge_index, edge_attr, action.edge_index)
if x_i == obs[0].x.shape[0]-1:
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def train(cfg: DictConfig) -> None:
wandb.init(
project=policy.__class__.__name__,
group=cfg.task.task_name,
name=f"v1.0.2",
name=f"v1.0.2 - fixed variance",
# track hyperparameters and run metadata
config={
"policy": cfg.policy,
Expand Down