From 5801863516784276e1f16fe840661cb4e02297ae Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 19 Feb 2024 14:13:21 +0100 Subject: [PATCH 1/5] optimize based on likelihood using MoG distribution --- .../config/policy/graph_diffusion_policy.yaml | 8 +- imitation/model/graph_diffusion.py | 58 +++++++--- imitation/policy/ar_graph_diffusion_policy.py | 108 ++++++++++++++---- 3 files changed, 128 insertions(+), 46 deletions(-) diff --git a/imitation/config/policy/graph_diffusion_policy.yaml b/imitation/config/policy/graph_diffusion_policy.yaml index 5e463a7..2cd7905 100644 --- a/imitation/config/policy/graph_diffusion_policy.yaml +++ b/imitation/config/policy/graph_diffusion_policy.yaml @@ -1,12 +1,12 @@ -_target_: imitation.policy.ar_graph_diffusion_policy.AutoregressiveGraphDiffusionPolicy +_target_: imitation.policy.ar_graph_diffusion_policy.StochAutoregressiveGraphDiffusionPolicy dataset: ${task.dataset} node_feature_dim: 8 # 9 when task-joint-space, 8 otherwise num_edge_types: 2 # robot joints, object-robot -lr: 0.0005 -ckpt_path: ./weights/lift_graph_diffusion_policy.pt +lr: 0.00005 +ckpt_path: /home/caio/workspace/GraphDiffusionImitate/weights/lift_graph_diffusion_policy.pt device: cuda denoising_network: - _target_: imitation.model.graph_diffusion.ConditionalGraphDenoisingNetwork + _target_: imitation.model.graph_diffusion.FiLMConditionalGraphDenoisingNetwork node_feature_dim: ${policy.node_feature_dim} obs_horizon: ${obs_horizon} pred_horizon: ${pred_horizon} diff --git a/imitation/model/graph_diffusion.py b/imitation/model/graph_diffusion.py index 72e32e1..f6e016c 100644 --- a/imitation/model/graph_diffusion.py +++ b/imitation/model/graph_diffusion.py @@ -185,11 +185,16 @@ def __init__(self, self.obs_horizon = obs_horizon self.pred_horizon = pred_horizon self.hidden_dim = hidden_dim - self.node_embedding = Linear(node_feature_dim*pred_horizon, hidden_dim).to(self.device) - self.edge_embedding = Linear(edge_feature_dim, hidden_dim).to(self.device) + self.num_mixtures = 2 # make sure it's power of 2 + + # Node embedding + self.node_embedding = Linear(node_feature_dim*pred_horizon, self.hidden_dim).to(self.device) + # Edge embedding + self.edge_embedding = Linear(edge_feature_dim, self.hidden_dim).to(self.device) + # FiLM modulation https://arxiv.org/abs/1709.07871 # predicts per-channel scale and bias - self.cond_channels = hidden_dim * 2 + self.cond_channels = self.hidden_dim * 2 self.cond_encoders = nn.ModuleList() for _ in range(num_layers): self.cond_encoders.append(nn.Sequential( @@ -198,20 +203,23 @@ def __init__(self, nn.Unflatten(-1, (-1, 1)) ).to(self.device)) - - + # Graph layers self.layers = nn.ModuleList() for _ in range(num_layers): - self.layers.append(MPLayer(hidden_dim, hidden_dim)).to(self.device) - - self.node_pred_layer = nn.Sequential(Linear(2 * hidden_dim, hidden_dim), + self.layers.append(MPLayer(self.hidden_dim, self.hidden_dim).to(self.device)) + + # Prediction layer (adjusted for multiple mixtures) + self.node_pred_layer = nn.Sequential( + Linear(2 * self.hidden_dim, self.hidden_dim), nn.ReLU(), - Linear(hidden_dim, hidden_dim), + Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(), - Linear(hidden_dim, self.node_feature_dim*self.pred_horizon) + Linear(self.hidden_dim, 2 * self.node_feature_dim * self.pred_horizon * self.num_mixtures + self.num_mixtures) # 2 * hidden_dim for means and variances, and num_mixtures for mixing weights ).to(self.device) + + # Positional encoding self.pe = self.positionalencoding(100).to(self.device) # max length of 100 - + def positionalencoding(self, lengths): ''' @@ -264,14 +272,30 @@ def forward(self, x, edge_index, edge_attr, cond=None, node_order=None): # repeat graph embedding to have the same shape as h_v graph_embedding = graph_embedding.repeat(h_v.shape[0], 1) - node_pred = self.node_pred_layer(torch.cat([graph_embedding, h_v], dim=1)) # hidden_dim + 1 + # Modified output layer to predict distribution parameters + dist_params = self.node_pred_layer(torch.cat([graph_embedding, h_v], dim=1)) + + # Split the output into means, variances, and mixing weights + means = dist_params[:, :self.node_feature_dim * self.pred_horizon * self.num_mixtures] + variances = dist_params[:, self.node_feature_dim * self.pred_horizon * self.num_mixtures:2 * self.node_feature_dim * self.pred_horizon * self.num_mixtures] + mixing_weights = dist_params[:, 2 * self.node_feature_dim * self.pred_horizon * self.num_mixtures:] + + # softmax the mixing weights and variances (to ensure positivity) + mixing_weights = F.softmax(mixing_weights, dim=-1) + variances = F.softplus(variances) + # Reshape the output to match the number of mixtures + means = means.view(-1, self.pred_horizon, self.node_feature_dim, self.num_mixtures) + variances = variances.view(-1, self.pred_horizon, self.node_feature_dim, self.num_mixtures) + mixing_weights = mixing_weights.view(-1, self.num_mixtures) + + # get only last node v_t = h_v.shape[0] - 1 # node being masked, this assumes that the masked node is the last node in the graph - - node_pred = node_pred[v_t] - node_pred = node_pred.reshape(self.pred_horizon, self.node_feature_dim) # reshape to original shape - - return node_pred + means = means[v_t] + variances = variances[v_t] + mixing_weights = mixing_weights[v_t] + + return means, variances, mixing_weights diff --git a/imitation/policy/ar_graph_diffusion_policy.py b/imitation/policy/ar_graph_diffusion_policy.py index f66d542..2a9c757 100644 --- a/imitation/policy/ar_graph_diffusion_policy.py +++ b/imitation/policy/ar_graph_diffusion_policy.py @@ -10,7 +10,7 @@ from imitation.utils.graph_diffusion import NodeMasker -class AutoregressiveGraphDiffusionPolicy(nn.Module): +class StochAutoregressiveGraphDiffusionPolicy(nn.Module): def __init__(self, dataset, node_feature_dim, @@ -20,7 +20,7 @@ def __init__(self, ckpt_path=None, device = None, mode = 'joint-space'): - super(AutoregressiveGraphDiffusionPolicy, self).__init__() + super(StochAutoregressiveGraphDiffusionPolicy, self).__init__() if device == None: self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: @@ -100,22 +100,80 @@ def node_decay_ordering(self, graph): ''' return torch.arange(graph.x.shape[0]-1, -1, -1) - def vlb(self, G_0, edge_type_probs, node, node_order, t): - T = len(node_order) + def nll_loss(self, G_0, node_features, node): n_i = G_0.x.shape[0] - # retrieve edge type from G_t.edge_attr, edges between node and node_order[t:] - edge_attrs_matrix = G_0.edge_attr.reshape(T, T) - original_edge_types = torch.index_select(edge_attrs_matrix[node], 0, torch.tensor(node_order[t:]).to(self.device)) - # calculate probability of edge type - p_edges = torch.gather(edge_type_probs, 1, original_edge_types.reshape(-1, 1)) - log_p_edges = torch.sum(torch.log(p_edges)) - # log_p_edges = torch.sum(torch.tensor([0])) - wandb.log({"target_edges_log_prob": log_p_edges}) - # calculate loss - log_p_O_v = log_p_edges - loss = -(n_i/T)*log_p_O_v # cumulative, to join (k) from all previously denoised nodes + # get likelihood of joint values + log_likelihood = self.get_distribution_likelihood(node_features, G_0.x[node,:,:]) + loss = - log_likelihood return loss + + def get_distribution_likelihood(self, dist_params, joint_values): + ''' + Returns the likelihood of joint values given a Mixture of Gaussians (MoG) distribution. + Args: + dist_params: A tensor containing parameters for the MoG distribution. + This should be in the format [means, variances, mixing_weights], + where means and variances are shaped [pred_horizon, node_feature_dim, num_mixtures], + and mixing_weights are shaped [num_mixtures]. + joint_values: A tensor containing joint values shaped [pred_horizon, node_feature_dim]. + ''' + # Extract parameters + means = dist_params[0].to(self.device).double() + variances = dist_params[1].to(self.device).double() + mixing_weights = dist_params[2].to(self.device).double() + + # Reshape mixing_weights for broadcasting + mixing_weights = mixing_weights.unsqueeze(0).unsqueeze(0) # Add two dimensions to the front + + # Calculate squared differences and normalize by variances + repeated_joint_values = joint_values.unsqueeze(2).repeat(1, 1, mixing_weights.shape[2]) + squared_diffs = (repeated_joint_values - means) ** 2 / variances + + # Calculate exponential terms and multiply with mixing weights + exp_terms = torch.exp(-0.5 * squared_diffs) * mixing_weights + + # Calculate likelihoods and sum over mixtures + likelihoods = exp_terms / torch.sqrt(2 * np.pi * variances) + likelihood_sum = torch.sum(likelihoods, dim=2) # sum over mixtures + + # Calculate and return log-likelihood + return torch.sum(torch.log(likelihood_sum)) + + + def sample_from_distribution(self, dist_params): + ''' + Samples joint values from a Mixture of Gaussians (MoG) distribution. + Args: + dist_params: A tensor containing parameters for the MoG distribution. + This should be in the format [means, variances, mixing_weights], + where means and variances are shaped [pred_horizon, node_feature_dim, num_mixtures], + and mixing_weights are shaped [num_mixtures]. + + Returns: + joint_values: A tensor containing sampled joint values shaped + [pred_horizon, num_features]. + ''' + + # Extract parameters + means = dist_params[0] + variances = dist_params[1] + mixing_weights = dist_params[2] + + # Get the index of the maximum mixing weight + max_mixture_idx = torch.argmax(mixing_weights) + + # Sample from Gaussian using the maximum mixing weight + pred_horizon = means.shape[0] + num_features = means.shape[1] + joint_values = torch.zeros(pred_horizon, num_features) # single joint/node + for i in range(pred_horizon): + for j in range(num_features): + # Sample from Gaussian using the maximum mixing weight + joint_values[i, j] = torch.normal(means[i, j, max_mixture_idx], torch.sqrt(variances[i, j, max_mixture_idx])) + + return joint_values + def train(self, dataset, num_epochs=100, model_path=None, seed=0): ''' Train noise prediction model @@ -148,21 +206,20 @@ def train(self, dataset, num_epochs=100, model_path=None, seed=0): G_pred = diffusion_trajectory[t+1].clone().to(self.device) # predict node and edge type distributions - joint_values = self.model(G_pred.x.float(), G_pred.edge_index, G_pred.edge_attr.float(), cond=G_0.y.float()) + logits = self.model(G_pred.x.float(), G_pred.edge_index, G_pred.edge_attr.float(), cond=G_0.y.float()) + # joint_values = node_features[0] # first node feature is joint values # mse loss for node features - loss = self.node_feature_loss(joint_values, G_0.x[node_order[t],:,:].float()) + # loss = self.node_feature_loss(joint_values, G_0.x[node_order[t],:,:].float()) + loss = self.nll_loss(G_0, logits, node_order[t]) + # TODO add loss for absolute positions, to make the model physics-informed - # use correlation as loss - # x = torch.stack([node_features.squeeze(), G_0.x[node_order[t],:,0].float()]) - # x += torch.rand_like(x) * 1e-8 # add noise to avoid NaNs - # loss -= (1/batch_size)*torch.corrcoef(x)[0,1] - wandb.log({"epoch": self.global_epoch, "loss": loss.item()}) + wandb.log({"epoch": self.global_epoch, "nll_loss": loss.item()}) acc_loss += loss.item() # backprop (accumulated gradients) - loss.backward(retain_graph=True) + loss.backward() batch_i += 1 # update weights if batch_i % batch_size == 0: @@ -182,7 +239,7 @@ def get_joint_values(self, x): - end-effector: raise NotImplementedError ''' if self.mode == 'joint-space' or self.mode == 'task-joint-space': - return x[:,:,0].T # all nodes, all timesteps, first value + return x[:8,:,0].T # all nodes, all timesteps, first value elif self.mode == 'end-effector': raise NotImplementedError else: @@ -226,7 +283,8 @@ def get_action(self, obs): for x_i in range(obs[0].x.shape[0]): # number of nodes in action graph action = self.preprocess(action) # predict node attributes for last node in action - action.x[-1] = self.model(action.x.float(), action.edge_index, action.edge_attr, cond=node_features) + logits = self.model(action.x.float(), action.edge_index, action.edge_attr, cond=node_features) + action.x[-1] = self.sample_from_distribution(logits) # map edge attributes from obs to action action.edge_attr = self._lookup_edge_attr(edge_index, edge_attr, action.edge_index) if x_i == obs[0].x.shape[0]-1: From 8ad5cb41e9a7009405f2a071f02d502b44443020 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 19 Feb 2024 14:13:21 +0100 Subject: [PATCH 2/5] optimize based on likelihood using MoG distribution --- .../config/policy/graph_diffusion_policy.yaml | 2 +- imitation/model/graph_diffusion.py | 58 +++++++--- imitation/policy/ar_graph_diffusion_policy.py | 106 ++++++++++++++---- 3 files changed, 124 insertions(+), 42 deletions(-) diff --git a/imitation/config/policy/graph_diffusion_policy.yaml b/imitation/config/policy/graph_diffusion_policy.yaml index fd02413..fd40c0f 100644 --- a/imitation/config/policy/graph_diffusion_policy.yaml +++ b/imitation/config/policy/graph_diffusion_policy.yaml @@ -1,4 +1,4 @@ -_target_: imitation.policy.ar_graph_diffusion_policy.AutoregressiveGraphDiffusionPolicy +_target_: imitation.policy.ar_graph_diffusion_policy.StochAutoregressiveGraphDiffusionPolicy dataset: ${task.dataset} node_feature_dim: 9 # 9 when task-joint-space, 8 otherwise num_edge_types: 2 # robot joints, object-robot diff --git a/imitation/model/graph_diffusion.py b/imitation/model/graph_diffusion.py index 72e32e1..f6e016c 100644 --- a/imitation/model/graph_diffusion.py +++ b/imitation/model/graph_diffusion.py @@ -185,11 +185,16 @@ def __init__(self, self.obs_horizon = obs_horizon self.pred_horizon = pred_horizon self.hidden_dim = hidden_dim - self.node_embedding = Linear(node_feature_dim*pred_horizon, hidden_dim).to(self.device) - self.edge_embedding = Linear(edge_feature_dim, hidden_dim).to(self.device) + self.num_mixtures = 2 # make sure it's power of 2 + + # Node embedding + self.node_embedding = Linear(node_feature_dim*pred_horizon, self.hidden_dim).to(self.device) + # Edge embedding + self.edge_embedding = Linear(edge_feature_dim, self.hidden_dim).to(self.device) + # FiLM modulation https://arxiv.org/abs/1709.07871 # predicts per-channel scale and bias - self.cond_channels = hidden_dim * 2 + self.cond_channels = self.hidden_dim * 2 self.cond_encoders = nn.ModuleList() for _ in range(num_layers): self.cond_encoders.append(nn.Sequential( @@ -198,20 +203,23 @@ def __init__(self, nn.Unflatten(-1, (-1, 1)) ).to(self.device)) - - + # Graph layers self.layers = nn.ModuleList() for _ in range(num_layers): - self.layers.append(MPLayer(hidden_dim, hidden_dim)).to(self.device) - - self.node_pred_layer = nn.Sequential(Linear(2 * hidden_dim, hidden_dim), + self.layers.append(MPLayer(self.hidden_dim, self.hidden_dim).to(self.device)) + + # Prediction layer (adjusted for multiple mixtures) + self.node_pred_layer = nn.Sequential( + Linear(2 * self.hidden_dim, self.hidden_dim), nn.ReLU(), - Linear(hidden_dim, hidden_dim), + Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(), - Linear(hidden_dim, self.node_feature_dim*self.pred_horizon) + Linear(self.hidden_dim, 2 * self.node_feature_dim * self.pred_horizon * self.num_mixtures + self.num_mixtures) # 2 * hidden_dim for means and variances, and num_mixtures for mixing weights ).to(self.device) + + # Positional encoding self.pe = self.positionalencoding(100).to(self.device) # max length of 100 - + def positionalencoding(self, lengths): ''' @@ -264,14 +272,30 @@ def forward(self, x, edge_index, edge_attr, cond=None, node_order=None): # repeat graph embedding to have the same shape as h_v graph_embedding = graph_embedding.repeat(h_v.shape[0], 1) - node_pred = self.node_pred_layer(torch.cat([graph_embedding, h_v], dim=1)) # hidden_dim + 1 + # Modified output layer to predict distribution parameters + dist_params = self.node_pred_layer(torch.cat([graph_embedding, h_v], dim=1)) + + # Split the output into means, variances, and mixing weights + means = dist_params[:, :self.node_feature_dim * self.pred_horizon * self.num_mixtures] + variances = dist_params[:, self.node_feature_dim * self.pred_horizon * self.num_mixtures:2 * self.node_feature_dim * self.pred_horizon * self.num_mixtures] + mixing_weights = dist_params[:, 2 * self.node_feature_dim * self.pred_horizon * self.num_mixtures:] + + # softmax the mixing weights and variances (to ensure positivity) + mixing_weights = F.softmax(mixing_weights, dim=-1) + variances = F.softplus(variances) + # Reshape the output to match the number of mixtures + means = means.view(-1, self.pred_horizon, self.node_feature_dim, self.num_mixtures) + variances = variances.view(-1, self.pred_horizon, self.node_feature_dim, self.num_mixtures) + mixing_weights = mixing_weights.view(-1, self.num_mixtures) + + # get only last node v_t = h_v.shape[0] - 1 # node being masked, this assumes that the masked node is the last node in the graph - - node_pred = node_pred[v_t] - node_pred = node_pred.reshape(self.pred_horizon, self.node_feature_dim) # reshape to original shape - - return node_pred + means = means[v_t] + variances = variances[v_t] + mixing_weights = mixing_weights[v_t] + + return means, variances, mixing_weights diff --git a/imitation/policy/ar_graph_diffusion_policy.py b/imitation/policy/ar_graph_diffusion_policy.py index d6ba7ed..d44d32c 100644 --- a/imitation/policy/ar_graph_diffusion_policy.py +++ b/imitation/policy/ar_graph_diffusion_policy.py @@ -10,7 +10,7 @@ from imitation.utils.graph_diffusion import NodeMasker -class AutoregressiveGraphDiffusionPolicy(nn.Module): +class StochAutoregressiveGraphDiffusionPolicy(nn.Module): def __init__(self, dataset, node_feature_dim, @@ -20,7 +20,7 @@ def __init__(self, ckpt_path=None, device = None, mode = 'joint-space'): - super(AutoregressiveGraphDiffusionPolicy, self).__init__() + super(StochAutoregressiveGraphDiffusionPolicy, self).__init__() if device == None: self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: @@ -100,22 +100,80 @@ def node_decay_ordering(self, graph): ''' return torch.arange(graph.x.shape[0]-1, -1, -1) - def vlb(self, G_0, edge_type_probs, node, node_order, t): - T = len(node_order) + def nll_loss(self, G_0, node_features, node): n_i = G_0.x.shape[0] - # retrieve edge type from G_t.edge_attr, edges between node and node_order[t:] - edge_attrs_matrix = G_0.edge_attr.reshape(T, T) - original_edge_types = torch.index_select(edge_attrs_matrix[node], 0, torch.tensor(node_order[t:]).to(self.device)) - # calculate probability of edge type - p_edges = torch.gather(edge_type_probs, 1, original_edge_types.reshape(-1, 1)) - log_p_edges = torch.sum(torch.log(p_edges)) - # log_p_edges = torch.sum(torch.tensor([0])) - wandb.log({"target_edges_log_prob": log_p_edges}) - # calculate loss - log_p_O_v = log_p_edges - loss = -(n_i/T)*log_p_O_v # cumulative, to join (k) from all previously denoised nodes + # get likelihood of joint values + log_likelihood = self.get_distribution_likelihood(node_features, G_0.x[node,:,:]) + loss = - log_likelihood return loss + + def get_distribution_likelihood(self, dist_params, joint_values): + ''' + Returns the likelihood of joint values given a Mixture of Gaussians (MoG) distribution. + Args: + dist_params: A tensor containing parameters for the MoG distribution. + This should be in the format [means, variances, mixing_weights], + where means and variances are shaped [pred_horizon, node_feature_dim, num_mixtures], + and mixing_weights are shaped [num_mixtures]. + joint_values: A tensor containing joint values shaped [pred_horizon, node_feature_dim]. + ''' + # Extract parameters + means = dist_params[0].to(self.device).double() + variances = dist_params[1].to(self.device).double() + mixing_weights = dist_params[2].to(self.device).double() + + # Reshape mixing_weights for broadcasting + mixing_weights = mixing_weights.unsqueeze(0).unsqueeze(0) # Add two dimensions to the front + + # Calculate squared differences and normalize by variances + repeated_joint_values = joint_values.unsqueeze(2).repeat(1, 1, mixing_weights.shape[2]) + squared_diffs = (repeated_joint_values - means) ** 2 / variances + + # Calculate exponential terms and multiply with mixing weights + exp_terms = torch.exp(-0.5 * squared_diffs) * mixing_weights + + # Calculate likelihoods and sum over mixtures + likelihoods = exp_terms / torch.sqrt(2 * np.pi * variances) + likelihood_sum = torch.sum(likelihoods, dim=2) # sum over mixtures + + # Calculate and return log-likelihood + return torch.sum(torch.log(likelihood_sum)) + + + def sample_from_distribution(self, dist_params): + ''' + Samples joint values from a Mixture of Gaussians (MoG) distribution. + Args: + dist_params: A tensor containing parameters for the MoG distribution. + This should be in the format [means, variances, mixing_weights], + where means and variances are shaped [pred_horizon, node_feature_dim, num_mixtures], + and mixing_weights are shaped [num_mixtures]. + + Returns: + joint_values: A tensor containing sampled joint values shaped + [pred_horizon, num_features]. + ''' + + # Extract parameters + means = dist_params[0] + variances = dist_params[1] + mixing_weights = dist_params[2] + + # Get the index of the maximum mixing weight + max_mixture_idx = torch.argmax(mixing_weights) + + # Sample from Gaussian using the maximum mixing weight + pred_horizon = means.shape[0] + num_features = means.shape[1] + joint_values = torch.zeros(pred_horizon, num_features) # single joint/node + for i in range(pred_horizon): + for j in range(num_features): + # Sample from Gaussian using the maximum mixing weight + joint_values[i, j] = torch.normal(means[i, j, max_mixture_idx], torch.sqrt(variances[i, j, max_mixture_idx])) + + return joint_values + def train(self, dataset, num_epochs=100, model_path=None, seed=0): ''' Train noise prediction model @@ -148,21 +206,20 @@ def train(self, dataset, num_epochs=100, model_path=None, seed=0): G_pred = diffusion_trajectory[t+1].clone().to(self.device) # predict node and edge type distributions - joint_values = self.model(G_pred.x.float(), G_pred.edge_index, G_pred.edge_attr.float(), cond=G_0.y.float()) + logits = self.model(G_pred.x.float(), G_pred.edge_index, G_pred.edge_attr.float(), cond=G_0.y.float()) + # joint_values = node_features[0] # first node feature is joint values # mse loss for node features - loss = self.node_feature_loss(joint_values, G_0.x[node_order[t],:,:].float()) + # loss = self.node_feature_loss(joint_values, G_0.x[node_order[t],:,:].float()) + loss = self.nll_loss(G_0, logits, node_order[t]) + # TODO add loss for absolute positions, to make the model physics-informed - # use correlation as loss - # x = torch.stack([node_features.squeeze(), G_0.x[node_order[t],:,0].float()]) - # x += torch.rand_like(x) * 1e-8 # add noise to avoid NaNs - # loss -= (1/batch_size)*torch.corrcoef(x)[0,1] - wandb.log({"epoch": self.global_epoch, "loss": loss.item()}) + wandb.log({"epoch": self.global_epoch, "nll_loss": loss.item()}) acc_loss += loss.item() # backprop (accumulated gradients) - loss.backward(retain_graph=True) + loss.backward() batch_i += 1 # update weights if batch_i % batch_size == 0: @@ -226,7 +283,8 @@ def get_action(self, obs): for x_i in range(obs[0].x.shape[0]): # number of nodes in action graph TODO remove objects action = self.preprocess(action) # predict node attributes for last node in action - action.x[-1] = self.model(action.x.float(), action.edge_index, action.edge_attr, cond=node_features) + logits = self.model(action.x.float(), action.edge_index, action.edge_attr, cond=node_features) + action.x[-1] = self.sample_from_distribution(logits) # map edge attributes from obs to action action.edge_attr = self._lookup_edge_attr(edge_index, edge_attr, action.edge_index) if x_i == obs[0].x.shape[0]-1: From 8cc67c15c72ddae8f18a885c3c5fc9eb9fe92d4b Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 11 Mar 2024 20:39:58 +0100 Subject: [PATCH 3/5] new file for MoG policy --- imitation/config/policy/graph_diffusion_policy.yaml | 2 +- ...aph_diffusion_policy.py => mog_ar_graph_diffusion_policy.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename imitation/policy/{ar_graph_diffusion_policy.py => mog_ar_graph_diffusion_policy.py} (100%) diff --git a/imitation/config/policy/graph_diffusion_policy.yaml b/imitation/config/policy/graph_diffusion_policy.yaml index fd40c0f..e094a22 100644 --- a/imitation/config/policy/graph_diffusion_policy.yaml +++ b/imitation/config/policy/graph_diffusion_policy.yaml @@ -1,4 +1,4 @@ -_target_: imitation.policy.ar_graph_diffusion_policy.StochAutoregressiveGraphDiffusionPolicy +_target_: imitation.policy.mog_ar_graph_diffusion_policy.StochAutoregressiveGraphDiffusionPolicy dataset: ${task.dataset} node_feature_dim: 9 # 9 when task-joint-space, 8 otherwise num_edge_types: 2 # robot joints, object-robot diff --git a/imitation/policy/ar_graph_diffusion_policy.py b/imitation/policy/mog_ar_graph_diffusion_policy.py similarity index 100% rename from imitation/policy/ar_graph_diffusion_policy.py rename to imitation/policy/mog_ar_graph_diffusion_policy.py From c7e5571a780762f5d7474f97ff8b9169be281813 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Mon, 11 Mar 2024 20:53:24 +0100 Subject: [PATCH 4/5] create mog_graph_diffusion_policy config, rename denoising network class --- ...iffusion_policy.yaml => mog_graph_diffusion_policy.yaml} | 3 ++- imitation/config/train.yaml | 2 +- imitation/model/graph_diffusion.py | 6 ++++-- imitation/policy/mog_ar_graph_diffusion_policy.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) rename imitation/config/policy/{graph_diffusion_policy.yaml => mog_graph_diffusion_policy.yaml} (83%) diff --git a/imitation/config/policy/graph_diffusion_policy.yaml b/imitation/config/policy/mog_graph_diffusion_policy.yaml similarity index 83% rename from imitation/config/policy/graph_diffusion_policy.yaml rename to imitation/config/policy/mog_graph_diffusion_policy.yaml index e094a22..1b3b1b7 100644 --- a/imitation/config/policy/graph_diffusion_policy.yaml +++ b/imitation/config/policy/mog_graph_diffusion_policy.yaml @@ -3,7 +3,7 @@ dataset: ${task.dataset} node_feature_dim: 9 # 9 when task-joint-space, 8 otherwise num_edge_types: 2 # robot joints, object-robot lr: 0.0005 -ckpt_path: ./weights/${task.task_name}_${task.dataset_type}_film_graph_diffusion_policy.pt +ckpt_path: ./weights/${task.task_name}_${task.dataset_type}_mog_graph_diffusion_policy.pt device: cuda denoising_network: _target_: imitation.model.graph_diffusion.FiLMConditionalGraphDenoisingNetwork @@ -14,3 +14,4 @@ denoising_network: num_edge_types: ${policy.num_edge_types} num_layers: 5 hidden_dim: 256 + num_mixtures: 1 \ No newline at end of file diff --git a/imitation/config/train.yaml b/imitation/config/train.yaml index 0d57bb2..a375152 100644 --- a/imitation/config/train.yaml +++ b/imitation/config/train.yaml @@ -1,7 +1,7 @@ defaults: - _self_ - task: lift_graph - - policy: graph_diffusion_policy + - policy: mog_graph_diffusion_policy output_dir: ./outputs # on evaluating, for environment wrapper diff --git a/imitation/model/graph_diffusion.py b/imitation/model/graph_diffusion.py index f6e016c..5d65cac 100644 --- a/imitation/model/graph_diffusion.py +++ b/imitation/model/graph_diffusion.py @@ -160,7 +160,7 @@ def forward(self, x, edge_index, edge_attr, v_t=None, h_v=None): return node_pred, p_e, h_v -class FiLMConditionalGraphDenoisingNetwork(nn.Module): +class MoGConditionalGraphDenoisingNetwork(nn.Module): def __init__(self, node_feature_dim, obs_horizon, @@ -169,6 +169,7 @@ def __init__(self, num_edge_types, num_layers=5, hidden_dim=256, + num_mixtures=1, device=None): ''' Denoising GNN (based on GraphARM) with FiLM conditioning on @@ -185,7 +186,8 @@ def __init__(self, self.obs_horizon = obs_horizon self.pred_horizon = pred_horizon self.hidden_dim = hidden_dim - self.num_mixtures = 2 # make sure it's power of 2 + assert num_mixtures & (num_mixtures - 1) == 0, "num_mixtures should be a power of 2" + self.num_mixtures = num_mixtures # Node embedding self.node_embedding = Linear(node_feature_dim*pred_horizon, self.hidden_dim).to(self.device) diff --git a/imitation/policy/mog_ar_graph_diffusion_policy.py b/imitation/policy/mog_ar_graph_diffusion_policy.py index d44d32c..ca610ec 100644 --- a/imitation/policy/mog_ar_graph_diffusion_policy.py +++ b/imitation/policy/mog_ar_graph_diffusion_policy.py @@ -184,7 +184,7 @@ def train(self, dataset, num_epochs=100, model_path=None, seed=0): pass self.optimizer.zero_grad() self.model.train() - batch_size = 5 + batch_size = 1 with tqdm(range(num_epochs), desc='Epoch', leave=False) as tepoch: for epoch in tepoch: From 1fcef777d34e56bcb314a26749f3aa8f1d2d5131 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Tue, 12 Mar 2024 19:03:41 +0100 Subject: [PATCH 5/5] fix variance (for debugging) --- imitation/config/eval.yaml | 4 ++-- .../policy/mog_graph_diffusion_policy.yaml | 2 +- .../policy/mog_ar_graph_diffusion_policy.py | 23 ++++++++++++------- train.py | 2 +- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/imitation/config/eval.yaml b/imitation/config/eval.yaml index 49f43cb..e09f25e 100644 --- a/imitation/config/eval.yaml +++ b/imitation/config/eval.yaml @@ -1,9 +1,9 @@ defaults: - _self_ - task: lift_graph - - policy: graph_diffusion_policy + - policy: mog_graph_diffusion_policy -render: False +render: True output_video: True num_episodes: 20 diff --git a/imitation/config/policy/mog_graph_diffusion_policy.yaml b/imitation/config/policy/mog_graph_diffusion_policy.yaml index 1b3b1b7..6f69b7b 100644 --- a/imitation/config/policy/mog_graph_diffusion_policy.yaml +++ b/imitation/config/policy/mog_graph_diffusion_policy.yaml @@ -6,7 +6,7 @@ lr: 0.0005 ckpt_path: ./weights/${task.task_name}_${task.dataset_type}_mog_graph_diffusion_policy.pt device: cuda denoising_network: - _target_: imitation.model.graph_diffusion.FiLMConditionalGraphDenoisingNetwork + _target_: imitation.model.graph_diffusion.MoGConditionalGraphDenoisingNetwork node_feature_dim: ${policy.node_feature_dim} obs_horizon: ${obs_horizon} pred_horizon: ${pred_horizon} diff --git a/imitation/policy/mog_ar_graph_diffusion_policy.py b/imitation/policy/mog_ar_graph_diffusion_policy.py index ca610ec..4c464e8 100644 --- a/imitation/policy/mog_ar_graph_diffusion_policy.py +++ b/imitation/policy/mog_ar_graph_diffusion_policy.py @@ -35,7 +35,7 @@ def __init__(self, self.global_epoch = 0 self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4) - self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience=50, factor=0.5, verbose=True, min_lr=lr/100) + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience=200, factor=0.6, verbose=True, min_lr=lr/10) self.mode = mode if ckpt_path is not None: @@ -100,17 +100,17 @@ def node_decay_ordering(self, graph): ''' return torch.arange(graph.x.shape[0]-1, -1, -1) - def nll_loss(self, G_0, node_features, node): + def nll_loss(self, G_0, dist_params, node): n_i = G_0.x.shape[0] # get likelihood of joint values - log_likelihood = self.get_distribution_likelihood(node_features, G_0.x[node,:,:]) + log_likelihood = self.get_distribution_likelihood(dist_params, G_0.x[node,:,:]) loss = - log_likelihood return loss def get_distribution_likelihood(self, dist_params, joint_values): ''' - Returns the likelihood of joint values given a Mixture of Gaussians (MoG) distribution. + Returns the likelihood of joint values given a Gaussian Mixture Model (GMM) distribution. Args: dist_params: A tensor containing parameters for the MoG distribution. This should be in the format [means, variances, mixing_weights], @@ -120,8 +120,10 @@ def get_distribution_likelihood(self, dist_params, joint_values): ''' # Extract parameters means = dist_params[0].to(self.device).double() - variances = dist_params[1].to(self.device).double() - mixing_weights = dist_params[2].to(self.device).double() + # variances = dist_params[1].to(self.device).double() + variances = torch.zeros_like(means).to(self.device).double() + 0.01 + # mixing_weights = dist_params[2].to(self.device).double() + mixing_weights = dist_params[2].to(self.device).double() # Reshape mixing_weights for broadcasting mixing_weights = mixing_weights.unsqueeze(0).unsqueeze(0) # Add two dimensions to the front @@ -130,13 +132,18 @@ def get_distribution_likelihood(self, dist_params, joint_values): repeated_joint_values = joint_values.unsqueeze(2).repeat(1, 1, mixing_weights.shape[2]) squared_diffs = (repeated_joint_values - means) ** 2 / variances - # Calculate exponential terms and multiply with mixing weights - exp_terms = torch.exp(-0.5 * squared_diffs) * mixing_weights + # Calculate exponential terms + exp_terms = torch.exp(-0.5 * squared_diffs) # Calculate likelihoods and sum over mixtures likelihoods = exp_terms / torch.sqrt(2 * np.pi * variances) + # Multiply by mixing weights + likelihoods = likelihoods * mixing_weights likelihood_sum = torch.sum(likelihoods, dim=2) # sum over mixtures + # avoid numerical instability + likelihood_sum = torch.clamp(likelihood_sum, min=1e-20) + # Calculate and return log-likelihood return torch.sum(torch.log(likelihood_sum)) diff --git a/train.py b/train.py index 98f7815..912523a 100644 --- a/train.py +++ b/train.py @@ -38,7 +38,7 @@ def train(cfg: DictConfig) -> None: wandb.init( project=policy.__class__.__name__, group=cfg.task.task_name, - name=f"v1.0.2", + name=f"v1.0.2 - fixed variance", # track hyperparameters and run metadata config={ "policy": cfg.policy,