From 422f904eca0027c6993d9a22805d831675451163 Mon Sep 17 00:00:00 2001 From: iamaray Date: Mon, 4 Mar 2024 21:54:16 -0500 Subject: [PATCH 1/5] (probably incorrectly) returning embeddings --- matdeeplearn/models/torchmd_et.py | 123 +++++++++++++++++------------- 1 file changed, 68 insertions(+), 55 deletions(-) diff --git a/matdeeplearn/models/torchmd_et.py b/matdeeplearn/models/torchmd_et.py index a72a67d7..249a2797 100644 --- a/matdeeplearn/models/torchmd_et.py +++ b/matdeeplearn/models/torchmd_et.py @@ -16,9 +16,9 @@ from matdeeplearn.models.torchmd_output_modules import Scalar, EquivariantScalar from matdeeplearn.common.registry import registry from matdeeplearn.preprocessor.helpers import node_rep_one_hot -@registry.register_model("torchmd_et") +@registry.register_model("torchmd_et") class TorchMD_ET(BaseModel): r"""The TorchMD equivariant Transformer architecture. @@ -60,7 +60,7 @@ class TorchMD_ET(BaseModel): def __init__( self, - node_dim, + node_dim, edge_dim, output_dim, hidden_channels=128, @@ -110,7 +110,8 @@ def __init__( self.distance_influence = distance_influence self.max_z = max_z self.pool = pool - assert pool_order in ['early', 'late'], f"{pool_order} is currently not supported" + assert pool_order in [ + 'early', 'late'], f"{pool_order} is currently not supported" self.pool_order = pool_order self.output_dim = output_dim cutoff_lower = 0 @@ -159,10 +160,13 @@ def __init__( self.post_lin_list = nn.ModuleList() for i in range(self.num_post_layers): if i == 0: - self.post_lin_list.append(nn.Linear(hidden_channels, post_hidden_channels)) + self.post_lin_list.append( + nn.Linear(hidden_channels, post_hidden_channels)) else: - self.post_lin_list.append(nn.Linear(post_hidden_channels, post_hidden_channels)) - self.post_lin_list.append(nn.Linear(post_hidden_channels, self.output_dim)) + self.post_lin_list.append( + nn.Linear(post_hidden_channels, post_hidden_channels)) + self.post_lin_list.append( + nn.Linear(post_hidden_channels, self.output_dim)) self.reset_parameters() @@ -174,39 +178,43 @@ def reset_parameters(self): for attn in self.attention_layers: attn.reset_parameters() self.out_norm.reset_parameters() - + @conditional_grad(torch.enable_grad()) def _forward(self, data): x = self.embedding(data.z) - #edge_index, edge_weight, edge_vec = self.distance(data.pos, data.batch) - #assert ( + # edge_index, edge_weight, edge_vec = self.distance(data.pos, data.batch) + # assert ( # edge_vec is not None - #), "Distance module did not return directional information" + # ), "Distance module did not return directional information" if self.otf_edge_index == True: - #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) - data.edge_index, data.edge_weight, data.edge_vec, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) - data.edge_attr = self.distance_expansion(data.edge_weight) - - #mask = data.edge_index[0] != data.edge_index[1] - #data.edge_vec[mask] = data.edge_vec[mask] / torch.norm(data.edge_vec[mask], dim=1).unsqueeze(1) - data.edge_vec = data.edge_vec / torch.norm(data.edge_vec, dim=1).unsqueeze(1) - + # data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + data.edge_index, data.edge_weight, data.edge_vec, _, _, _ = self.generate_graph( + data, self.cutoff_radius, self.n_neighbors) + data.edge_attr = self.distance_expansion(data.edge_weight) + + # mask = data.edge_index[0] != data.edge_index[1] + # data.edge_vec[mask] = data.edge_vec[mask] / torch.norm(data.edge_vec[mask], dim=1).unsqueeze(1) + data.edge_vec = data.edge_vec / \ + torch.norm(data.edge_vec, dim=1).unsqueeze(1) + if self.otf_node_attr == True: - data.x = node_rep_one_hot(data.z).float() - + data.x = node_rep_one_hot(data.z).float() + if self.neighbor_embedding is not None: - x = self.neighbor_embedding(data.z, x, data.edge_index, data.edge_weight, data.edge_attr) + x = self.neighbor_embedding( + data.z, x, data.edge_index, data.edge_weight, data.edge_attr) vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) for attn in self.attention_layers: - dx, dvec = attn(x, vec, data.edge_index, data.edge_weight, data.edge_attr, data.edge_vec) + dx, dvec = attn(x, vec, data.edge_index, + data.edge_weight, data.edge_attr, data.edge_vec) x = x + dx vec = vec + dvec x = self.out_norm(x) - + if self.prediction_level == "graph": if self.pool_order == 'early': x = getattr(torch_geometric.nn, self.pool)(x, data.batch) @@ -216,40 +224,40 @@ def _forward(self, data): x = self.post_lin_list[-1](x) if self.pool_order == 'late': x = getattr(torch_geometric.nn, self.pool)(x, data.batch) - #x = self.pool.pre_reduce(x, vec, data.z, data.pos, data.batch) - #x = self.pool.reduce(x, data.batch) + # x = self.pool.pre_reduce(x, vec, data.z, data.pos, data.batch) + # x = self.pool.reduce(x, data.batch) elif self.prediction_level == "node": for i in range(0, len(self.post_lin_list) - 1): x = self.post_lin_list[i](x) x = getattr(F, self.activation)(x) - x = self.post_lin_list[-1](x) - + x = self.post_lin_list[-1](x) + return x - + def forward(self, data): - + output = {} out = self._forward(data) - output["output"] = out - - if self.gradient == True and out.requires_grad == True: - volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) - grad = torch.autograd.grad( - out, - [data.pos, data.displacement], - grad_outputs=torch.ones_like(out), - create_graph=self.training) - forces = -1 * grad[0] - stress = grad[1] - stress = stress / volume.view(-1, 1, 1) - - output["pos_grad"] = forces - output["cell_grad"] = stress - else: - output["pos_grad"] = None - output["cell_grad"] = None - - return output + output["output"] = out + + # if self.gradient == True and out.requires_grad == True: + # volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + # grad = torch.autograd.grad( + # out, + # [data.pos, data.displacement], + # grad_outputs=torch.ones_like(out), + # create_graph=self.training) + # forces = -1 * grad[0] + # stress = grad[1] + # stress = stress / volume.view(-1, 1, 1) + + # output["pos_grad"] = forces + # output["cell_grad"] = stress + # else: + # output["pos_grad"] = None + # output["cell_grad"] = None + + return output def __repr__(self): return ( @@ -267,6 +275,7 @@ def __repr__(self): f"cutoff_lower={self.cutoff_lower}, " f"self.cutoff_radius={self.self.cutoff_radius})" ) + @property def target_attr(self): return "y" @@ -285,7 +294,8 @@ def __init__( cutoff_upper, aggregation, ): - super(EquivariantMultiHeadAttention, self).__init__(aggr=aggregation, node_dim=0) + super(EquivariantMultiHeadAttention, self).__init__( + aggr=aggregation, node_dim=0) assert hidden_channels % num_heads == 0, ( f"The number of hidden channels ({hidden_channels}) " f"must be evenly divisible by the number of " @@ -307,7 +317,8 @@ def __init__( self.v_proj = nn.Linear(hidden_channels, hidden_channels * 3) self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3) - self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False) + self.vec_proj = nn.Linear( + hidden_channels, hidden_channels * 3, bias=False) self.dk_proj = None if distance_influence in ["keys", "both"]: @@ -343,21 +354,23 @@ def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij): k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim) v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim * 3) - vec1, vec2, vec3 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1) + vec1, vec2, vec3 = torch.split( + self.vec_proj(vec), self.hidden_channels, dim=-1) vec = vec.reshape(-1, 3, self.num_heads, self.head_dim) vec_dot = (vec1 * vec2).sum(dim=1) dk = ( - self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim) + self.act(self.dk_proj(f_ij)).reshape(-1, + self.num_heads, self.head_dim) if self.dk_proj is not None else None ) dv = ( - self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim * 3) + self.act(self.dv_proj(f_ij)).reshape(-1, + self.num_heads, self.head_dim * 3) if self.dv_proj is not None else None ) - # propagate_type: (q: Tensor, k: Tensor, v: Tensor, vec: Tensor, dk: Tensor, dv: Tensor, r_ij: Tensor, d_ij: Tensor) x, vec = self.propagate( From df261754c04943f7c1cd94844dba32a8a2749717 Mon Sep 17 00:00:00 2001 From: iamaray Date: Thu, 7 Mar 2024 19:19:59 -0500 Subject: [PATCH 2/5] Outputting embeddings I think --- matdeeplearn/models/torchmd_et.py | 68 ++++++++++++++++--------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/matdeeplearn/models/torchmd_et.py b/matdeeplearn/models/torchmd_et.py index 249a2797..0620e0a5 100644 --- a/matdeeplearn/models/torchmd_et.py +++ b/matdeeplearn/models/torchmd_et.py @@ -213,24 +213,25 @@ def _forward(self, data): data.edge_weight, data.edge_attr, data.edge_vec) x = x + dx vec = vec + dvec + # just output the embeddings => stop before the prediction layer x = self.out_norm(x) - if self.prediction_level == "graph": - if self.pool_order == 'early': - x = getattr(torch_geometric.nn, self.pool)(x, data.batch) - for i in range(0, len(self.post_lin_list) - 1): - x = self.post_lin_list[i](x) - x = getattr(F, self.activation)(x) - x = self.post_lin_list[-1](x) - if self.pool_order == 'late': - x = getattr(torch_geometric.nn, self.pool)(x, data.batch) - # x = self.pool.pre_reduce(x, vec, data.z, data.pos, data.batch) - # x = self.pool.reduce(x, data.batch) - elif self.prediction_level == "node": - for i in range(0, len(self.post_lin_list) - 1): - x = self.post_lin_list[i](x) - x = getattr(F, self.activation)(x) - x = self.post_lin_list[-1](x) + # if self.prediction_level == "graph": + # if self.pool_order == 'early': + # x = getattr(torch_geometric.nn, self.pool)(x, data.batch) + # for i in range(0, len(self.post_lin_list) - 1): + # x = self.post_lin_list[i](x) + # x = getattr(F, self.activation)(x) + # x = self.post_lin_list[-1](x) + # if self.pool_order == 'late': + # x = getattr(torch_geometric.nn, self.pool)(x, data.batch) + # # x = self.pool.pre_reduce(x, vec, data.z, data.pos, data.batch) + # # x = self.pool.reduce(x, data.batch) + # elif self.prediction_level == "node": + # for i in range(0, len(self.post_lin_list) - 1): + # x = self.post_lin_list[i](x) + # x = getattr(F, self.activation)(x) + # x = self.post_lin_list[-1](x) return x @@ -240,22 +241,25 @@ def forward(self, data): out = self._forward(data) output["output"] = out - # if self.gradient == True and out.requires_grad == True: - # volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) - # grad = torch.autograd.grad( - # out, - # [data.pos, data.displacement], - # grad_outputs=torch.ones_like(out), - # create_graph=self.training) - # forces = -1 * grad[0] - # stress = grad[1] - # stress = stress / volume.view(-1, 1, 1) - - # output["pos_grad"] = forces - # output["cell_grad"] = stress - # else: - # output["pos_grad"] = None - # output["cell_grad"] = None + # this is skipped reached since we're not getting the prediction (I think?) + # even if it is reached, we're probably fine lol. + if self.gradient == True and out.requires_grad == True: + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross( + data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, data.displacement], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) + + output["pos_grad"] = forces + output["cell_grad"] = stress + else: + output["pos_grad"] = None + output["cell_grad"] = None return output From 178d57262ed4a2c58b40328704bb88f515f27e12 Mon Sep 17 00:00:00 2001 From: iamaray Date: Thu, 7 Mar 2024 19:21:20 -0500 Subject: [PATCH 3/5] added todo comment --- matdeeplearn/models/torchmd_et.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/matdeeplearn/models/torchmd_et.py b/matdeeplearn/models/torchmd_et.py index 0620e0a5..e016bfac 100644 --- a/matdeeplearn/models/torchmd_et.py +++ b/matdeeplearn/models/torchmd_et.py @@ -233,6 +233,8 @@ def _forward(self, data): # x = getattr(F, self.activation)(x) # x = self.post_lin_list[-1](x) + # TODO: FIGURE OUT HOW TO ACCESS EMBEDDINGS; WE NEED THEM TO COMPUTE + # MOLECULAR FINGERPRINTS. return x def forward(self, data): From fc7d92d383bc9e457dd474f007b6edae2faa51e0 Mon Sep 17 00:00:00 2001 From: iamaray Date: Mon, 11 Mar 2024 16:25:20 -0400 Subject: [PATCH 4/5] readout operation applied to atomic embeddings --- matdeeplearn/models/torchmd_et.py | 1 + matdeeplearn/trainers/property_trainer.py | 331 ++++++++++++---------- 2 files changed, 188 insertions(+), 144 deletions(-) diff --git a/matdeeplearn/models/torchmd_et.py b/matdeeplearn/models/torchmd_et.py index e016bfac..a33121d6 100644 --- a/matdeeplearn/models/torchmd_et.py +++ b/matdeeplearn/models/torchmd_et.py @@ -235,6 +235,7 @@ def _forward(self, data): # TODO: FIGURE OUT HOW TO ACCESS EMBEDDINGS; WE NEED THEM TO COMPUTE # MOLECULAR FINGERPRINTS. + return x def forward(self, data): diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py index 06916021..3900ab49 100644 --- a/matdeeplearn/trainers/property_trainer.py +++ b/matdeeplearn/trainers/property_trainer.py @@ -55,7 +55,7 @@ def __init__( output_frequency, model_save_frequency, save_dir, - checkpoint_path, + checkpoint_path, use_amp, ) @@ -68,7 +68,7 @@ def train(self): if str(self.rank) not in ("cpu", "cuda"): dist.barrier() - + end_epoch = ( self.max_checkpoint_epochs + start_epoch if self.max_checkpoint_epochs @@ -85,63 +85,75 @@ def train(self): logging.info( f"Running for {end_epoch - start_epoch} epochs on {type(self.model[0]).__name__} model" ) - - for epoch in range(start_epoch, end_epoch): + + for epoch in range(start_epoch, end_epoch): epoch_start_time = time.time() if self.train_sampler: self.train_sampler.set_epoch(epoch) # skip_steps = self.step % len(self.train_loader) train_loader_iter = [] for i in range(len(self.model)): - train_loader_iter.append(iter(self.data_loader[i]["train_loader"])) + train_loader_iter.append( + iter(self.data_loader[i]["train_loader"])) # metrics for every epoch _metrics = [{} for _ in range(len(self.model))] - - #for i in range(skip_steps, len(self.train_loader)): - pbar = tqdm(range(0, len(self.data_loader[0]["train_loader"])), disable=not self.batch_tqdm) - for i in pbar: - #self.epoch = epoch + (i + 1) / len(self.train_loader) - #self.step = epoch * len(self.train_loader) + i + 1 - #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) + + # for i in range(skip_steps, len(self.train_loader)): + pbar = tqdm( + range(0, len(self.data_loader[0]["train_loader"])), disable=not self.batch_tqdm) + for i in pbar: + # self.epoch = epoch + (i + 1) / len(self.train_loader) + # self.step = epoch * len(self.train_loader) + i + 1 + # print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) batch = [] for n, mod in enumerate(self.model): mod.train() batch.append(next(train_loader_iter[n]).to(self.rank)) # Get a batch of train data - # batch = next(train_loader_iter).to(self.rank) - # print(epoch, i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024), torch.sum(batch.n_atoms)) - # Compute forward, loss, backward + # batch = next(train_loader_iter).to(self.rank) + # print(epoch, i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024), torch.sum(batch.n_atoms)) + # Compute forward, loss, backward with autocast(enabled=self.use_amp): - out_list = self._forward(batch) - loss = self._compute_loss(out_list, batch) - #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) + out_list = self._forward(batch) + out = out_list[0] + # Perform a readout operation on the atomic node embeddings + # to obtain a representation of the entire molecule. + # TODO: We need to extract this vector somehow + readout = torch.exp(torch.mean(torch.log(out), dim=1)) + loss = self._compute_loss(out_list, batch) + + # print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) grad_norm = [] for i in range(len(self.model)): grad_norm.append(self._backward(loss[i], i)) - pbar.set_description("Batch Loss {:.4f}, grad norm {:.4f}".format(torch.mean(torch.stack(loss)).item(), torch.mean(torch.stack(grad_norm)).item())) + pbar.set_description("Batch Loss {:.4f}, grad norm {:.4f}".format(torch.mean( + torch.stack(loss)).item(), torch.mean(torch.stack(grad_norm)).item())) # Compute metrics # TODO: revert _metrics to be empty per batch, so metrics are logged per batch, not per epoch - # keep option to log metrics per epoch + # keep option to log metrics per epoch for n in range(len(self.model)): - _metrics[n] = self._compute_metrics(out_list[n], batch[n], _metrics[n]) - self.metrics[n] = self.evaluator.update("loss", loss[n].item(), out_list[n]["output"].shape[0], _metrics[n]) + _metrics[n] = self._compute_metrics( + out_list[n], batch[n], _metrics[n]) + self.metrics[n] = self.evaluator.update( + "loss", loss[n].item(), out_list[n]["output"].shape[0], _metrics[n]) self.epoch = epoch + 1 if str(self.rank) not in ("cpu", "cuda"): dist.barrier() - # TODO: could add param to eval and save on increments instead of every time - - # Save current model - torch.cuda.empty_cache() + # TODO: could add param to eval and save on increments instead of every time + + # Save current model + torch.cuda.empty_cache() if str(self.rank) in ("0", "cpu", "cuda"): if self.model_save_frequency == 1: - self.save_model(checkpoint_file="checkpoint.pt", training_state=True) + self.save_model( + checkpoint_file="checkpoint.pt", training_state=True) # Evaluate on validation set if it exists if self.data_loader[0].get("val_loader"): - metric = self.validate("val") + metric = self.validate("val") else: metric = self.metrics @@ -159,48 +171,53 @@ def train(self): if metric[i][type(self.loss_fn).__name__]["metric"] < self.best_metric[i]: if self.output_frequency == 0: if self.model_save_frequency == 1: - self.update_best_model(metric[i], i, write_model=True, write_csv=False) + self.update_best_model( + metric[i], i, write_model=True, write_csv=False) else: - self.update_best_model(metric[i], i, write_model=False, write_csv=False) + self.update_best_model( + metric[i], i, write_model=False, write_csv=False) elif self.output_frequency == 1: if self.model_save_frequency == 1: - self.update_best_model(metric[i], i, write_model=True, write_csv=True) + self.update_best_model( + metric[i], i, write_model=True, write_csv=True) else: - self.update_best_model(metric[i], i, write_model=False, write_csv=True) - + self.update_best_model( + metric[i], i, write_model=False, write_csv=True) + self._scheduler_step() - - torch.cuda.empty_cache() - + torch.cuda.empty_cache() + if self.best_model_state: for i in range(len(self.model)): if str(self.rank) in "0": - self.model[i].module.load_state_dict(self.best_model_state[i]) + self.model[i].module.load_state_dict( + self.best_model_state[i]) elif str(self.rank) in ("cpu", "cuda"): self.model[i].load_state_dict(self.best_model_state[i]) - #if self.data_loader.get("test_loader"): + # if self.data_loader.get("test_loader"): # metric = self.validate("test") # test_loss = metric[type(self.loss_fn).__name__]["metric"] - #else: - # test_loss = "N/A" + # else: + # test_loss = "N/A" if self.model_save_frequency != -1: - self.save_model("best_checkpoint.pt", index=None, metric=metric, training_state=True) - logging.info("Final Losses: ") + self.save_model("best_checkpoint.pt", index=None, + metric=metric, training_state=True) + logging.info("Final Losses: ") if "train" in self.write_output: self.predict(self.data_loader[0]["train_loader"], "train") if "val" in self.write_output and self.data_loader[0].get("val_loader"): self.predict(self.data_loader[0]["val_loader"], "val") if "test" in self.write_output and self.data_loader[0].get("test_loader"): - self.predict(self.data_loader[0]["test_loader"], "test") - + self.predict(self.data_loader[0]["test_loader"], "test") + return self.best_model_state - + @torch.no_grad() def validate(self, split="val"): for i in range(len(self.model)): self.model[i].eval() - + evaluator, metrics = Evaluator(), [{} for _ in range(len(self.model))] loader_iter = [] @@ -211,31 +228,33 @@ def validate(self, split="val"): loader_iter.append(iter(self.data_loader[i]["test_loader"])) elif split == "train": loader_iter.append(iter(self.data_loader[i]["train_loader"])) - + for i in range(0, len(loader_iter[0])): - #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) + # print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) batch = [] for i in range(len(self.model)): batch.append(next(loader_iter[i]).to(self.rank)) - + out_list = self._forward(batch) loss = self._compute_loss(out_list, batch) # Compute metrics - #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) + # print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024)) for n in range(len(self.model)): - metrics[n] = self._compute_metrics(out_list[n], batch[n], metrics[n]) - metrics[n] = evaluator.update("loss", loss[n].item(), out_list[n]["output"].shape[0], metrics[n]) + metrics[n] = self._compute_metrics( + out_list[n], batch[n], metrics[n]) + metrics[n] = evaluator.update( + "loss", loss[n].item(), out_list[n]["output"].shape[0], metrics[n]) del loss, batch, out_list - + torch.cuda.empty_cache() - + return metrics @torch.no_grad() - def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True): + def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True): for mod in self.model: mod.eval() - + # assert isinstance(loader, torch.utils.data.dataloader.DataLoader) # TODO: make this compatible with model ensemble @@ -243,7 +262,7 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, loader = get_dataloader( loader.dataset, batch_size=loader.batch_size, sampler=None ) - + evaluator, metrics = Evaluator(), {} predict, target = None, None ids = [] @@ -251,15 +270,15 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, target_pos_grad = None ids_cell_grad = [] target_cell_grad = None - node_level = False - - loader_iter = iter(loader) + node_level = False + + loader_iter = iter(loader) for i in range(0, len(loader_iter)): batch = next(loader_iter).to(self.rank) out_list = self._forward([batch]) - + out = {} - out_stack={} + out_stack = {} for key in out_list[0].keys(): temp = [o[key] for o in out_list] if temp[0] is not None: @@ -269,12 +288,11 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, else: out[key] = None out[key+"_std"] = None - - + batch_p = [o["output"].data.cpu().numpy() for o in out_list] - batch_p_mean = out["output"].cpu().numpy() + batch_p_mean = out["output"].cpu().numpy() batch_ids = batch.structure_id - batch_stds = out["output_std"].cpu().numpy() + batch_stds = out["output_std"].cpu().numpy() if labels == True: loss = self._compute_loss(out, batch) @@ -283,12 +301,13 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, "loss", loss.item(), out["output"].shape[0], metrics ) if str(self.rank) not in ("cpu", "cuda"): - batch_t = batch[self.model[0].module.target_attr].cpu().numpy() + batch_t = batch[self.model[0].module.target_attr].cpu( + ).numpy() else: batch_t = batch[self.model[0].target_attr].cpu().numpy() - - # Node level prediction - if batch_p[0].shape[0] > loader.batch_size: + + # Node level prediction + if batch_p[0].shape[0] > loader.batch_size: node_level = True node_ids = batch.z.cpu().numpy() structure_ids = np.repeat( @@ -303,137 +322,157 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, structure_ids_pos_grad = np.repeat( batch.structure_id, batch.n_atoms.cpu().numpy(), axis=0 ) - batch_ids_pos_grad = np.column_stack((structure_ids_pos_grad, node_ids_pos_grad)) - ids_pos_grad = batch_ids_pos_grad if i == 0 else np.row_stack((ids_pos_grad, batch_ids_pos_grad)) - predict_pos_grad = batch_p_pos_grad if i == 0 else np.concatenate((predict_pos_grad, batch_p_pos_grad), axis=0) - predict_pos_grad_std = batch_p_pos_grad_std if i == 0 else np.concatenate((predict_pos_grad_std, batch_p_pos_grad_std), axis=0) + batch_ids_pos_grad = np.column_stack( + (structure_ids_pos_grad, node_ids_pos_grad)) + ids_pos_grad = batch_ids_pos_grad if i == 0 else np.row_stack( + (ids_pos_grad, batch_ids_pos_grad)) + predict_pos_grad = batch_p_pos_grad if i == 0 else np.concatenate( + (predict_pos_grad, batch_p_pos_grad), axis=0) + predict_pos_grad_std = batch_p_pos_grad_std if i == 0 else np.concatenate( + (predict_pos_grad_std, batch_p_pos_grad_std), axis=0) if "forces" in batch: - batch_t_pos_grad = batch["forces"].cpu().numpy() - target_pos_grad = batch_t_pos_grad if i == 0 else np.concatenate((target_pos_grad, batch_t_pos_grad), axis=0) - - if out.get("cell_grad") != None: - batch_p_cell_grad = out["cell_grad"].data.view(out["cell_grad"].data.size(0), -1).cpu().numpy() - batch_p_cell_grad_std = out["cell_grad_std"].data.view(out["cell_grad"].data.size(0), -1).cpu().numpy() - batch_ids_cell_grad = batch.structure_id - ids_cell_grad = batch_ids_cell_grad if i == 0 else np.row_stack((ids_cell_grad, batch_ids_cell_grad)) - predict_cell_grad = batch_p_cell_grad if i == 0 else np.concatenate((predict_cell_grad, batch_p_cell_grad), axis=0) - predict_cell_grad_std = batch_p_cell_grad_std if i == 0 else np.concatenate((predict_cell_grad_std, batch_p_cell_grad_std), axis=0) + batch_t_pos_grad = batch["forces"].cpu().numpy() + target_pos_grad = batch_t_pos_grad if i == 0 else np.concatenate( + (target_pos_grad, batch_t_pos_grad), axis=0) + + if out.get("cell_grad") != None: + batch_p_cell_grad = out["cell_grad"].data.view( + out["cell_grad"].data.size(0), -1).cpu().numpy() + batch_p_cell_grad_std = out["cell_grad_std"].data.view( + out["cell_grad"].data.size(0), -1).cpu().numpy() + batch_ids_cell_grad = batch.structure_id + ids_cell_grad = batch_ids_cell_grad if i == 0 else np.row_stack( + (ids_cell_grad, batch_ids_cell_grad)) + predict_cell_grad = batch_p_cell_grad if i == 0 else np.concatenate( + (predict_cell_grad, batch_p_cell_grad), axis=0) + predict_cell_grad_std = batch_p_cell_grad_std if i == 0 else np.concatenate( + (predict_cell_grad_std, batch_p_cell_grad_std), axis=0) if "stress" in batch: - batch_t_cell_grad = batch["stress"].view(out["cell_grad"].data.size(0), -1).cpu().numpy() - target_cell_grad = batch_t_cell_grad if i == 0 else np.concatenate((target_cell_grad, batch_t_cell_grad), axis=0) - - ids = batch_ids if i == 0 else np.row_stack((ids, batch_ids)) - predict_mean = batch_p_mean if i == 0 else np.concatenate((predict_mean, batch_p_mean), axis=0) - stds = batch_stds if i == 0 else np.row_stack((stds, batch_stds)) - if i == 0: - predict = [0 for _ in range(len(self.model))] + batch_t_cell_grad = batch["stress"].view( + out["cell_grad"].data.size(0), -1).cpu().numpy() + target_cell_grad = batch_t_cell_grad if i == 0 else np.concatenate( + (target_cell_grad, batch_t_cell_grad), axis=0) + + ids = batch_ids if i == 0 else np.row_stack((ids, batch_ids)) + predict_mean = batch_p_mean if i == 0 else np.concatenate( + (predict_mean, batch_p_mean), axis=0) + stds = batch_stds if i == 0 else np.row_stack((stds, batch_stds)) + if i == 0: + predict = [0 for _ in range(len(self.model))] for x in range(len(self.model)): - predict[x] = batch_p[x] if i == 0 else np.concatenate((predict[x], batch_p[x]), axis=0) + predict[x] = batch_p[x] if i == 0 else np.concatenate( + (predict[x], batch_p[x]), axis=0) if labels == True: - target = batch_t if i == 0 else np.concatenate((target, batch_t), axis=0) - + target = batch_t if i == 0 else np.concatenate( + (target, batch_t), axis=0) + if labels == True: - del loss, batch, out - else: - del batch, out - + del loss, batch, out + else: + del batch, out + if write_output == True: if labels == True: - if len(self.model) > 1: + if len(self.model) > 1: self.save_results( np.column_stack((ids, target, predict_mean, stds)), results_dir, f"{split}_predictions.csv", node_level, True, std=True, - ) + ) for x in range(len(self.model)): - mod = str(x) + mod = str(x) self.save_results( np.column_stack((ids, target, predict[x])), results_dir, f"{split}_predictions_{mod}.csv", node_level, True, std=False, ) - else: + else: self.save_results( np.column_stack((ids, target, predict_mean)), results_dir, f"{split}_predictions.csv", node_level, True, std=False, - ) + ) else: - if len(self.model) > 1: + if len(self.model) > 1: self.save_results( np.column_stack((ids, predict_mean, stds)), results_dir, f"{split}_predictions.csv", node_level, False, std=True, - ) + ) for x in range(len(self.model)): mod = str(x) self.save_results( - np.column_stack((ids, predict[x])), results_dir, f"{split}_predictions_{mod}.csv", node_level, False, std=False, + np.column_stack((ids, predict[x])), results_dir, f"{split}_predictions_{mod}.csv", node_level, False, std=False, ) - else: + else: self.save_results( np.column_stack((ids, predict_mean)), results_dir, f"{split}_predictions.csv", node_level, False, std=False, - ) - #if out.get("pos_grad") != None: + ) + # if out.get("pos_grad") != None: if len(ids_pos_grad) > 0: if isinstance(target_pos_grad, np.ndarray): - if len(self.model) > 1: + if len(self.model) > 1: self.save_results( np.column_stack((ids_pos_grad, target_pos_grad, predict_pos_grad, predict_pos_grad_std)), results_dir, f"{split}_predictions_pos_grad.csv", True, True, std=True ) else: self.save_results( np.column_stack((ids_pos_grad, target_pos_grad, predict_pos_grad)), results_dir, f"{split}_predictions_pos_grad.csv", True, True, std=False - ) + ) else: self.save_results( np.column_stack((ids_pos_grad, predict_pos_grad)), results_dir, f"{split}_predictions_pos_grad.csv", True, False, std=False ) - #if out.get("cell_grad") != None: + # if out.get("cell_grad") != None: if len(ids_cell_grad) > 0: if isinstance(target_cell_grad, np.ndarray): - if len(self.model) > 1: + if len(self.model) > 1: self.save_results( np.column_stack((ids_cell_grad, target_cell_grad, predict_cell_grad, predict_cell_grad_std)), results_dir, f"{split}_predictions_cell_grad.csv", False, True, std=True ) else: self.save_results( np.column_stack((ids_cell_grad, target_cell_grad, predict_cell_grad)), results_dir, f"{split}_predictions_cell_grad.csv", False, True, std=False - ) + ) else: self.save_results( np.column_stack((ids_cell_grad, predict_cell_grad)), results_dir, f"{split}_predictions_cell_grad.csv", False, False, std=False ) - + if labels == True: predict_loss = metrics[type(self.loss_fn).__name__]["metric"] - logging.info("Saved {:s} error: {:.5f}".format(split, predict_loss)) - if len(self.model) > 1: - predictions = {"ids":ids, "predict":predict_mean, "target":target, "std": stds} + logging.info("Saved {:s} error: {:.5f}".format( + split, predict_loss)) + if len(self.model) > 1: + predictions = {"ids": ids, "predict": predict_mean, + "target": target, "std": stds} else: - predictions = {"ids":ids, "predict":predict_mean, "target":target} + predictions = {"ids": ids, + "predict": predict_mean, "target": target} else: - if len(self.model) > 1: - predictions = {"ids":ids, "predict":predict_mean, "std": stds} + if len(self.model) > 1: + predictions = {"ids": ids, + "predict": predict_mean, "std": stds} else: - predictions = {"ids":ids, "predict":predict_mean} + predictions = {"ids": ids, "predict": predict_mean} torch.cuda.empty_cache() - + return predictions - - def predict_by_calculator(self, loader): + + def predict_by_calculator(self, loader): for x, mod in self.model: mod.eval() - + assert isinstance(loader, torch.utils.data.dataloader.DataLoader) - assert len(loader) == 1, f"Predicting by calculator only allows one structure at a time, but got {len(loader)} structures." + assert len( + loader) == 1, f"Predicting by calculator only allows one structure at a time, but got {len(loader)} structures." if str(self.rank) not in ("cpu", "cuda"): loader = get_dataloader( loader.dataset, batch_size=loader.batch_size, sampler=None ) - + results = [] loader_iter = iter(loader) for i in range(0, len(loader_iter)): - batch = next(loader_iter).to(self.rank) + batch = next(loader_iter).to(self.rank) out_list = self._forward(batch.to(self.rank)) out = {} - out_stack={} + out_stack = {} for key in out_list[0].keys(): temp = [o[key] for o in out_list] if temp[0] is not None: @@ -442,12 +481,15 @@ def predict_by_calculator(self, loader): else: out[key] = None - energy = None if out.get('output') is None else out.get('output').data.cpu().numpy() - stress = None if out.get('cell_grad') is None else out.get('cell_grad').view(-1, 3).data.cpu().numpy() - forces = None if out.get('pos_grad') is None else out.get('pos_grad').data.cpu().numpy() - + energy = None if out.get('output') is None else out.get( + 'output').data.cpu().numpy() + stress = None if out.get('cell_grad') is None else out.get( + 'cell_grad').view(-1, 3).data.cpu().numpy() + forces = None if out.get('pos_grad') is None else out.get( + 'pos_grad').data.cpu().numpy() + results = {'energy': energy, 'stress': stress, 'forces': forces} - + return results def _forward(self, batch_data): @@ -480,9 +522,8 @@ def _backward(self, loss, index=None): ) self.scaler.step(self.optimizer[index]) self.scaler.update() - - return grad_norm + return grad_norm def _compute_metrics(self, out, batch_data, metrics): # TODO: finish this method @@ -493,14 +534,15 @@ def _compute_metrics(self, out, batch_data, metrics): metrics = self.evaluator.eval( out, property_target, self.loss_fn, prev_metrics=metrics - ) + ) return metrics def _log_metrics(self, val_metrics=None): - train_loss = [torch.tensor(i[type(self.loss_fn).__name__]["metric"]) for i in self.metrics] + train_loss = [torch.tensor( + i[type(self.loss_fn).__name__]["metric"]) for i in self.metrics] train_loss = torch.mean(torch.stack(train_loss)).item() - lr = self.scheduler[0].lr + lr = self.scheduler[0].lr if not val_metrics: val_loss = "N/A" logging.info( @@ -513,7 +555,8 @@ def _log_metrics(self, val_metrics=None): ) ) else: - val_loss = [torch.tensor(i[type(self.loss_fn).__name__]["metric"]) for i in val_metrics] + val_loss = [torch.tensor( + i[type(self.loss_fn).__name__]["metric"]) for i in val_metrics] val_loss = torch.mean(torch.stack(val_loss)).item() lr = self.scheduler[0].lr logging.info( @@ -526,7 +569,6 @@ def _log_metrics(self, val_metrics=None): ) ) - def _load_task(self): """Initializes task-specific info. Implemented by derived classes.""" pass @@ -535,7 +577,8 @@ def _scheduler_step(self): for i in range(len(self.model)): if self.scheduler[i].scheduler_type == "ReduceLROnPlateau": self.scheduler[i].step( - metrics=self.metrics[i][type(self.loss_fn).__name__]["metric"] + metrics=self.metrics[i][type( + self.loss_fn).__name__]["metric"] ) else: self.scheduler[i].step() From 37666a9e45725e1761e90cecd5826246607fda04 Mon Sep 17 00:00:00 2001 From: iamaray Date: Fri, 29 Mar 2024 14:36:30 -0400 Subject: [PATCH 5/5] made an oopsy; fixed it --- matdeeplearn/trainers/property_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py index 3900ab49..831dad08 100644 --- a/matdeeplearn/trainers/property_trainer.py +++ b/matdeeplearn/trainers/property_trainer.py @@ -118,6 +118,7 @@ def train(self): out = out_list[0] # Perform a readout operation on the atomic node embeddings # to obtain a representation of the entire molecule. + # TODO: We need to extract this vector somehow readout = torch.exp(torch.mean(torch.log(out), dim=1)) loss = self._compute_loss(out_list, batch)