From b3895b3c10dd93d15238971a64801440828efa95 Mon Sep 17 00:00:00 2001
From: Ian Slagle
Date: Mon, 12 Feb 2024 22:34:13 -0500
Subject: [PATCH 1/3] Add vmap_pred option to predict

---
 matdeeplearn/trainers/property_trainer.py | 51 ++++++++++++++---------
 1 file changed, 32 insertions(+), 19 deletions(-)

diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py
index 06916021..509b27d3 100644
--- a/matdeeplearn/trainers/property_trainer.py
+++ b/matdeeplearn/trainers/property_trainer.py
@@ -1,5 +1,6 @@
 import logging
 import time
+import copy
 
 import numpy as np
 import torch
@@ -232,9 +233,16 @@ def validate(self, split="val"):
         return metrics
 
     @torch.no_grad()
-    def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True):
+    def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True, vmap_pred = False):
         for mod in self.model:
             mod.eval()
+        if vmap_pred:
+            params, buffers = stack_module_state(self.model)
+            base_model = copy.deepcopy(self.model[0])
+            base_model = base_model.to('meta')
+            # TODO: Allow to work with pos_grad and cell_grad
+            def fmodel(params, buffers, x):
+                return functional_call(base_model, (params, buffers), (x,))['output']
 
         # assert isinstance(loader, torch.utils.data.dataloader.DataLoader)
 
@@ -256,25 +264,30 @@ def predict(self, loader, split, results_dir="train_results", write_output=True,
         loader_iter = iter(loader)
         for i in range(0, len(loader_iter)):
             batch = next(loader_iter).to(self.rank)
-            out_list = self._forward([batch])
-
-            out = {}
-            out_stack={}
-            for key in out_list[0].keys():
-                temp = [o[key] for o in out_list]
-                if temp[0] is not None:
-                    out_stack[key] = torch.stack(temp)
-                    out[key] = torch.mean(out_stack[key], dim=0)
-                    out[key+"_std"] = torch.std(out_stack[key], dim=0)
-                else:
-                    out[key] = None
-                    out[key+"_std"] = None
-
-
-            batch_p = [o["output"].data.cpu().numpy() for o in out_list]
-            batch_p_mean = out["output"].cpu().numpy()
+            out = {}
+            out_stack={}
+            if not vmap_pred:
+                out_list = self._forward([batch])
+                for key in out_list[0].keys():
+                    temp = [o[key] for o in out_list]
+                    if temp[0] is not None:
+                        out_stack[key] = torch.stack(temp)
+                        out[key] = torch.mean(out_stack[key], dim=0)
+                        out[key+"_std"] = torch.std(out_stack[key], dim=0)
+                    else:
+                        out[key] = None
+                        out[key+"_std"] = None
+                batch_p = [o["output"].data.cpu().numpy() for o in out_list]
+
+            else:
+                out_list = vmap(fmodel, in_dims = (0, 0, None))(self.params, self.buffers, batch)
+                out["output"] = torch.mean(out_list, dim = 0)
+                out["output_std"] = torch.std(out_list, dim = 0)
+                batch_p = [out_list[o].cpu().numpy() for o in range(out_list.size()[0])]
+
+            batch_p_mean = out["output"].cpu().numpy()
+            batch_stds = out["output_std"].cpu().numpy()
             batch_ids = batch.structure_id
-            batch_stds = out["output_std"].cpu().numpy()
 
             if labels == True:
                 loss = self._compute_loss(out, batch)

From 149e7bfb5f12d8f975d4336d336b07d0ee7bc134 Mon Sep 17 00:00:00 2001
From: Ian Slagle
Date: Mon, 12 Feb 2024 22:36:57 -0500
Subject: [PATCH 2/3] Pass in vmap_pred as a config argument

---
 matdeeplearn/tasks/task.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/matdeeplearn/tasks/task.py b/matdeeplearn/tasks/task.py
index 6a48172e..16ea18f5 100644
--- a/matdeeplearn/tasks/task.py
+++ b/matdeeplearn/tasks/task.py
@@ -75,6 +75,7 @@ def run(self):
         # if isinstance(self.trainer.data_loader, list):
         self.trainer.predict(
             loader=self.trainer.data_loader, split="predict", results_dir=results_dir, labels=self.config["task"]["labels"],
+            vmap_pred = self.config["task"].get("vmap_pred", False)
         )
         # else:
         #     self.trainer.predict(

From 36b5e954e7867e4afb1f26b4e38027b2ad523ec1 Mon Sep 17 00:00:00 2001
From: Ian Slagle
Date: Mon, 12 Feb 2024 22:43:18 -0500
Subject: [PATCH 3/3] Fix formatting

---
 matdeeplearn/trainers/property_trainer.py | 52 +++++++++++------------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py
index 509b27d3..374dabf9 100644
--- a/matdeeplearn/trainers/property_trainer.py
+++ b/matdeeplearn/trainers/property_trainer.py
@@ -236,13 +236,13 @@ def validate(self, split="val"):
     def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True, vmap_pred = False):
         for mod in self.model:
             mod.eval()
-        if vmap_pred:
-            params, buffers = stack_module_state(self.model)
+        if vmap_pred:
+            params, buffers = stack_module_state(self.model)
             base_model = copy.deepcopy(self.model[0])
             base_model = base_model.to('meta')
-            # TODO: Allow to work with pos_grad and cell_grad
-            def fmodel(params, buffers, x):
-                return functional_call(base_model, (params, buffers), (x,))['output']
+            # TODO: Allow to work with pos_grad and cell_grad
+            def fmodel(params, buffers, x):
+                return functional_call(base_model, (params, buffers), (x,))['output']
 
         # assert isinstance(loader, torch.utils.data.dataloader.DataLoader)
 
@@ -264,29 +264,29 @@ def fmodel(params, buffers, x):
         loader_iter = iter(loader)
         for i in range(0, len(loader_iter)):
             batch = next(loader_iter).to(self.rank)
-            out = {}
-            out_stack={}
-            if not vmap_pred:
-                out_list = self._forward([batch])
-                for key in out_list[0].keys():
-                    temp = [o[key] for o in out_list]
-                    if temp[0] is not None:
-                        out_stack[key] = torch.stack(temp)
-                        out[key] = torch.mean(out_stack[key], dim=0)
-                        out[key+"_std"] = torch.std(out_stack[key], dim=0)
-                    else:
-                        out[key] = None
-                        out[key+"_std"] = None
-                batch_p = [o["output"].data.cpu().numpy() for o in out_list]
+            out = {}
+            out_stack={}
+            if not vmap_pred:
+                out_list = self._forward([batch])
+                for key in out_list[0].keys():
+                    temp = [o[key] for o in out_list]
+                    if temp[0] is not None:
+                        out_stack[key] = torch.stack(temp)
+                        out[key] = torch.mean(out_stack[key], dim=0)
+                        out[key+"_std"] = torch.std(out_stack[key], dim=0)
+                    else:
+                        out[key] = None
+                        out[key+"_std"] = None
+                batch_p = [o["output"].data.cpu().numpy() for o in out_list]
 
-            else:
-                out_list = vmap(fmodel, in_dims = (0, 0, None))(self.params, self.buffers, batch)
-                out["output"] = torch.mean(out_list, dim = 0)
-                out["output_std"] = torch.std(out_list, dim = 0)
-                batch_p = [out_list[o].cpu().numpy() for o in range(out_list.size()[0])]
-
+            else:
+                out_list = vmap(fmodel, in_dims = (0, 0, None))(self.params, self.buffers, batch)
+                out["output"] = torch.mean(out_list, dim = 0)
+                out["output_std"] = torch.std(out_list, dim = 0)
+                batch_p = [out_list[o].cpu().numpy() for o in range(out_list.size()[0])]
+
             batch_p_mean = out["output"].cpu().numpy()
-            batch_stds = out["output_std"].cpu().numpy()
+            batch_stds = out["output_std"].cpu().numpy()
             batch_ids = batch.structure_id
 
             if labels == True:
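
For readers unfamiliar with the torch.func ensembling pattern these patches rely on, the following is a minimal standalone sketch (not part of the patch series) of how stack_module_state, functional_call, and vmap combine to run an ensemble's forward passes in one batched call and derive a mean prediction plus a per-sample spread, analogous to out["output"] and out["output_std"] in predict(). The TinyNet module and tensor shapes are illustrative stand-ins rather than matdeeplearn code, and the sketch simply passes the local params/buffers returned by stack_module_state into vmap.

    # Sketch of vmapped ensemble inference with torch.func (illustrative, not matdeeplearn code).
    import copy

    import torch
    from torch.func import functional_call, stack_module_state, vmap


    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(8, 1)

        def forward(self, x):
            return self.lin(x)


    # An "ensemble" of independently initialized models, mirroring self.model in the trainer.
    models = [TinyNet().eval() for _ in range(4)]

    # Stack each member's parameters and buffers along a new leading ensemble dimension.
    params, buffers = stack_module_state(models)

    # A stateless copy on the meta device supplies the module structure without its own weights.
    base_model = copy.deepcopy(models[0]).to("meta")


    def fmodel(p, b, x):
        # Run one ensemble member's forward pass using the given parameter/buffer slice.
        return functional_call(base_model, (p, b), (x,))


    x = torch.randn(16, 8)
    with torch.no_grad():
        # in_dims=(0, 0, None): map over the ensemble dimension of params/buffers,
        # broadcast the same input batch to every member.
        out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)

    mean_pred = out.mean(dim=0)  # ensemble mean, analogous to out["output"]
    std_pred = out.std(dim=0)    # ensemble spread, analogous to out["output_std"]
    print(mean_pred.shape, std_pred.shape)  # torch.Size([16, 1]) torch.Size([16, 1])

With the second patch applied, the path is enabled from the run configuration: predict() receives self.config["task"].get("vmap_pred", False), so setting vmap_pred to True under the task section of the config (an assumed layout, matching how labels is read there) routes prediction through the vmapped ensemble branch.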