diff --git a/configs/examples/config_potnet.yml b/configs/examples/config_potnet.yml new file mode 100644 index 00000000..634eafcb --- /dev/null +++ b/configs/examples/config_potnet.yml @@ -0,0 +1,135 @@ +trainer: property + +task: + #run_mode: train + identifier: my_train_job + parallel: False + # If seed is not set, then it will be random every time + seed: 12345678 + # Defaults to run directory if not specified + save_dir: + # continue from a previous job + continue_job: False + # spefcify if the training state is loaded: epochs, learning rate, etc + load_training_state: False + # Path to the checkpoint.pt file + checkpoint_path: + # Whether to write predictions to csv file. E.g. ["train", "val", "test"] + write_output: [train, val, test] + # Frequency of writing to file; 0 denotes writing only at the end, 1 denotes writing every time + output_frequency: 0 + # Frequency of saving model .pt file; 0 denotes saving only at the end, 1 denotes saving every time, -1 denotes never saving; this controls both checkpoint and best_checkpoint + model_save_frequency: 0 + # Specify if labels are provided for the predict task + # labels: True + # Use amp mixed precision + use_amp: True + +model: + name: CGCNN + # model attributes + dim1: 100 + dim2: 150 + pre_fc_count: 1 + gc_count: 4 + post_fc_count: 3 + pool: global_add_pool + pool_order: early + batch_norm: True + batch_track_stats: True + act: relu + dropout_rate: 0.0 + # Compute edge indices on the fly in the model forward + otf_edge_index: False + # Compute edge attributes on the fly in the model forward + otf_edge_attr: False + # Compute node attributes on the fly in the model forward + otf_node_attr: False + # 1 indicates normal behavior, larger numbers indicate the number of models to be used + model_ensemble: 1 + # compute gradients w.r.t to positions and cell, requires otf_edge_attr=True + gradient: False + +optim: + max_epochs: 200 + max_checkpoint_epochs: 0 + lr: 0.001 + # Either custom or from torch.nn.functional 
library. If from torch, loss_type is TorchLossWrapper + loss: + loss_type: TorchLossWrapper + loss_args: {loss_fn: l1_loss} + # gradient clipping value + clip_grad_norm: 10 + batch_size: 100 + optimizer: + optimizer_type: AdamW + optimizer_args: {} + scheduler: + scheduler_type: ReduceLROnPlateau + scheduler_args: {mode: min, factor: 0.8, patience: 10, min_lr: 0.00001, threshold: 0.0002} + # Training print-out frequency (print once every n epochs) + verbosity: 5 + # tqdm progress bar per batch in the epoch + batch_tqdm: False + +dataset: + name: test_data + # Whether the data has already been processed and a data.pt file is present from a previous run + processed: False + # Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path} + src: data/test_data/data_graph_scalar.json + # Path to target file within data_path - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path} or left blank when the dataset is a single json file + # Example: target_path: "data/raw_graph_scalar/targets.csv" + target_path: + # Path to save processed data.pt file + pt_path: data/ + # Either "node" or "graph" level + prediction_level: graph + + transforms: + - name: GetY + args: + # index specifies the index of a target vector to predict, which is useful when there are multiple property labels for a single dataset + # For example, an index: 0 (default) will use the first entry in the target vector + # if all values are to be predicted simultaneously, then specify index: -1 + index: -1 + otf_transform: True # Optional parameter, default is True + # Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html) + data_format: json + # specify if additional attributes are to be loaded into the dataset from the .json file; e.g. 
additional_attributes: [forces, stress] + additional_attributes: + # Print out processing info + verbose: True + # Index of target column in targets.csv + # graph specific settings + preprocess_params: + # one of mdl (minimum image convention), ocp (all neighbors included), inf (infinite potentials in addition to mdl) + edge_calc_method: inf + # determine if edges are computed, if false, then they need to be computed on the fly + preprocess_edges: True + # determine if edge attributes are computed during processing, if false, then they need to be computed on the fly + preprocess_edge_features: True + # determine if node attributes are computed during processing, if false, then they need to be computed on the fly + preprocess_node_features: True + # distance cutoff to determine if two atoms are connected by an edge + cutoff_radius : 8.0 + # maximum number of neighbors to consider (usually an arbitrarily high number to consider all neighbors) + n_neighbors : 250 + # number of pbc offsets to consider when determining neighbors (usually not changed) + num_offsets: 2 + # dimension of node attributes + node_dim : 100 + # dimension of edge attributes + edge_dim : 100 + # whether or not to add self-loops + self_loop: True + # Method of obtaining atom dictionary: available: (onehot) + node_representation: onehot + # Number of workers for dataloader, see https://pytorch.org/docs/stable/data.html + num_workers: 0 + # Where the dataset is loaded; either "cpu" or "cuda" + dataset_device: cpu + # Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data) + train_ratio: 0.8 + val_ratio: 0.05 + test_ratio: 0.15 diff --git a/matdeeplearn/preprocessor/helpers.py b/matdeeplearn/preprocessor/helpers.py index 2b15349d..320404a4 100644 --- a/matdeeplearn/preprocessor/helpers.py +++ b/matdeeplearn/preprocessor/helpers.py @@ -4,7 +4,7 @@ import sys from itertools import combinations, product from pathlib import Path -from typing import Literal +from 
typing import Literal, Optional import ase import numpy as np @@ -17,6 +17,14 @@ from torch_scatter import scatter_min, segment_coo, segment_csr from torch_sparse import SparseTensor +from matdeeplearn.preprocessor.inf_functions.series import ( + cython_gsl_sf_gamma, + cython_gsl_sf_gamma_inc, + cython_upper_bessel, + cython_upper_bessel_k, +) + + def calculate_edges_master( method: Literal["ase", "ocp", "mdl"], r: float, @@ -30,7 +38,7 @@ def calculate_edges_master( experimental_distance: bool = False, device: torch.device = torch.device("cpu"), ) -> dict[str, torch.Tensor]: - """Generates edges using one of three methods (ASE, OCP, or MDL implementations) due to limitations of each method. + """Generates edges using one of three methods (ASE, OCP, MDL, or INF implementations) due to limitations of each method. Args: r (float): cutoff radius n_neighbors (int): number of neighbors to consider @@ -39,13 +47,12 @@ def calculate_edges_master( pos (torch.Tensor): positions of atom in unit cell """ - out = dict() neighbors = torch.empty(0) cell_offset_distances = torch.empty(0) - #check if cell consists of all zeros; if a cell is not present when processing input data, it is set to torch.zeros() - if not torch.any(cell>0.0): + # check if cell consists of all zeros; if a cell is not present when processing input data, it is set to torch.zeros() + if not torch.any(cell > 0.0): cell = None method = "mdl" @@ -62,18 +69,24 @@ def calculate_edges_master( # get into correct shape for model stage edge_vec = edge_vec[edge_index[0], edge_index[1]] - - #elif method == "ase": + + # elif method == "ase": # edge_index, cell_offsets, edge_weights, edge_vec = calculate_edges_ase( # all_neighbors, r, n_neighbors, structure_id, cell.squeeze(0), pos # ) - + elif method == "ocp": # OCP requires a different format for the cell cell = cell.view(1, 3, 3) - + edge_index, cell_offsets, neighbors = radius_graph_pbc( - r, n_neighbors, pos, cell, torch.tensor([len(pos)], device = device), [True, 
True, True], offset_number + r, + n_neighbors, + pos, + cell, + torch.tensor([len(pos)], device=device), + [True, True, True], + offset_number, ) ocp_out = get_pbc_distances( @@ -91,6 +104,60 @@ def calculate_edges_master( cell_offsets = ocp_out["offsets"] edge_vec = ocp_out["distance_vec"] + elif method == "inf": + coefficients=[-0.801, -0.074, 0.145] + + u = ( + torch.arange(0, z.size(0), 1) + .unsqueeze(1) + .repeat((1, z.size(0))) + .flatten() + .long() + ) + v = ( + torch.arange(0, z.size(0), 1) + .unsqueeze(0) + .repeat((z.size(0), 1)) + .flatten() + .long() + ) + edge_index = torch.stack([u, v]) + + out["inf_edge_index"] = edge_index + + vecs = ( + pos[u.flatten().numpy().astype(np.int32)] + - pos[v.flatten().numpy().astype(np.int32)] + ) + + potentials = calculate_inf_potentials( + vecs.cpu().detach().numpy(), + np.squeeze(cell.cpu().detach().numpy()), + R=int(r), + ) + + cutoff_distance_matrix, cell_offsets, edge_vec = get_cutoff_distance_matrix( + pos, + cell, + 10000000, + 10000000, + offset_number=offset_number, + ) + + edge_index, edge_weights = dense_to_sparse(cutoff_distance_matrix) + + edge_index, edge_weights,_ = add_selfloop(len(z),edge_index,edge_weights,cutoff_distance_matrix,True) + # get into correct shape for model stage + edge_vec = edge_vec[edge_index[0], edge_index[1]] + + weightedSum = sum( + [potentials[i] * coefficients[i] for i in range(len(coefficients))] + ) + + out["inf_edge_attr"] = ( + torch.from_numpy(weightedSum).unsqueeze(-1).to(torch.float32) + ) + out["edge_index"] = edge_index out["edge_weights"] = edge_weights out["cell_offsets"] = cell_offsets @@ -99,6 +166,7 @@ def calculate_edges_master( return out + @contextlib.contextmanager def prof_ctx(): """Primitive debug tool which allows profiling of PyTorch code""" @@ -110,6 +178,7 @@ def prof_ctx(): logging.debug(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10)) + def argmax(arr: list[dict], key: str) -> int: """List of Dict argmax utility function Args: @@ 
-158,6 +227,7 @@ def get_mask( mask = torch.argwhere(cond1 & cond2).squeeze(1) return mask + def threshold_sort(all_distances: torch.Tensor, r: float, n_neighbors: int): # A = all_distances.clone().detach() A = all_distances @@ -259,6 +329,7 @@ def clean_up(data_list, attr_list): for attr in removable_attrs: delattr(data, attr) + def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = "cpu"): """ Get the periodic boundary condition (PBC) offsets for a unit cell @@ -279,6 +350,7 @@ def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = "cpu"): offsets = torch.tensor(offsets, device=device, dtype=torch.float) return offsets @ cell, offsets + def get_distances_pbc( positions: torch.Tensor, offsets: torch.Tensor, @@ -301,11 +373,11 @@ def get_distances_pbc( mic: bool minimum image convention """ - + # convert numpy array to torch tensors n_atoms = len(positions) n_cells = len(offsets[0]) - + pos1 = positions.view(-1, 1, 1, 3).expand(-1, n_atoms, n_cells, 3) pos2 = positions.view(1, -1, 1, 3).expand(n_atoms, -1, n_cells, 3) offsets = offsets.view(-1, n_cells, 3).expand(pos2.shape[0], n_cells, 3) @@ -323,7 +395,6 @@ def get_distances_pbc( # get minimum min_atomic_distances, min_indices = torch.min(atomic_distances, dim=-1) expanded_min_indices = min_indices.clone().detach() - atom_rij = (pos1 - pos2).squeeze(2) @@ -334,6 +405,7 @@ def get_distances_pbc( return min_atomic_distances, min_indices, atom_rij + def get_distances( positions: torch.Tensor, device: str = "cpu", @@ -349,9 +421,7 @@ def get_distances( return atomic_distances, atom_rij -def get_cutoff_distance_matrix( - pos, cell, r, n_neighbors, offset_number=3 -): +def get_cutoff_distance_matrix(pos, cell, r, n_neighbors, offset_number=3): """ get the distance matrix TODO: need to tune this for elongated structures @@ -372,23 +442,25 @@ def get_cutoff_distance_matrix( max number of neighbors to be considered """ device = pos.device - + if cell != None: cells, cell_coors = 
get_pbc_cells(cell, offset_number, device=device) - distance_matrix, min_indices, atom_rij = get_distances_pbc(pos, cells, device=device) - + distance_matrix, min_indices, atom_rij = get_distances_pbc( + pos, cells, device=device + ) + cutoff_distance_matrix = threshold_sort(distance_matrix, r, n_neighbors) - + # if image_selfloop: # # output of threshold sort has diagonal == 0 # # fill in the original values # self_loop_diag = distance_matrix.diagonal() # cutoff_distance_matrix.diagonal().copy_(self_loop_diag) - + all_cell_offsets = cell_coors[torch.flatten(min_indices)] all_cell_offsets = all_cell_offsets.view(len(pos), -1, 3) # cell_offsets = all_cell_offsets[cutoff_distance_matrix != 0] - + # self loops will always have cell of (0,0,0) # N: no of selfloops; M: no of non selfloop edges # self loops are the last N edge_index pairs @@ -398,11 +470,11 @@ def get_cutoff_distance_matrix( # get cells for edges except for self loops cell_offsets[:n_edges, :] = all_cell_offsets[cutoff_distance_matrix != 0] cell_offsets = cell_offsets[:n_edges] - + elif cell == None: - distance_matrix, atom_rij = get_distances(pos, device=device) + distance_matrix, atom_rij = get_distances(pos, device=device) cutoff_distance_matrix = threshold_sort(distance_matrix, r, n_neighbors) - cell_offsets = torch.zeros((1,3)) + cell_offsets = torch.zeros((1, 3)) return cutoff_distance_matrix, cell_offsets, atom_rij @@ -431,7 +503,8 @@ def add_selfloop( def node_rep_one_hot(Z): - return F.one_hot(Z - 1, num_classes = 100) + return F.one_hot(Z - 1, num_classes=100) + def node_rep_from_file(node_representation="onehot"): node_rep_path = Path(__file__).parent @@ -455,7 +528,10 @@ def node_rep_from_file(node_representation="onehot"): return loaded_rep -def generate_node_features(input_data, n_neighbors, device, use_degree=False, node_rep_func = node_rep_one_hot): + +def generate_node_features( + input_data, n_neighbors, device, use_degree=False, node_rep_func=node_rep_one_hot +): if 
def generate_edge_features_inf(input_data, edge_steps, r, device):
    """Build edge attributes from distances plus infinite-summation potentials.

    For every graph, the normalized edge distances and the precomputed
    ``inf_edge_attr`` values are each expanded with the same
    ``GaussianSmearing`` instance and concatenated along the last axis into
    ``edge_attr``.

    Args:
        input_data: a single ``Data`` object or a list of them; each must
            carry ``edge_descriptor["distance"]`` and ``inf_edge_attr``.
            Objects are modified in place.
        edge_steps: total edge feature budget; ``edge_steps // 2`` is handed
            to ``GaussianSmearing`` (presumably the number of Gaussians per
            expansion -- TODO confirm against GaussianSmearing).
        r: cutoff passed to ``normalize_edge_cutoff`` for the distances.
        device: device on which the Gaussian expansion is created.
    """
    gaussian = GaussianSmearing(0, 1, edge_steps // 2, 0.2, device=device)

    if isinstance(input_data, Data):
        input_data = [input_data]

    normalize_edge_cutoff(input_data, "distance", r)
    for data in input_data:
        distance_feat = gaussian(data.edge_descriptor["distance"])
        inf_feat = gaussian(data.inf_edge_attr)
        # Drop the singleton column of inf_edge_attr but always keep the edge
        # axis. An unbounded torch.squeeze would also remove the edge axis
        # when there is a single row (or the basis axis when it has width 1),
        # which makes the concatenation below fail.
        # Assumes the smearing puts its basis dimension last -- TODO confirm.
        inf_feat = inf_feat.reshape(-1, inf_feat.shape[-1])
        # NOTE(review): this requires inf_edge_attr and the distance
        # descriptor to have one row per edge in the same order; verify
        # against calculate_edges_master.
        data.edge_attr = torch.cat((distance_feat, inf_feat), dim=-1)
Calculate triplets - if (offsets is None): - _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = tripletsOld( - edge_index, num_nodes - ) + if offsets is None: + _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = tripletsOld(edge_index, num_nodes) else: idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets( edge_index, offsets.to(device=edge_index.device), num_nodes @@ -563,7 +652,7 @@ def compute_bond_angles( pos_i = pos[idx_i] pos_j = pos[idx_j] - if (offsets is not None): + if offsets is not None: offsets = offsets.to(pos.device) pos_ji, pos_kj = ( pos[idx_j] - pos_i + offsets[idx_ji], @@ -581,7 +670,8 @@ def compute_bond_angles( angle = torch.atan2(b, a) return angle, idx_kj, idx_ji - + + def get_pbc_distances( pos, edge_index, @@ -635,6 +725,7 @@ def get_pbc_distances( return out + def get_max_neighbors_mask(natoms, index, atom_distance, max_num_neighbors_threshold): """ From https://github.com/Open-Catalyst-Project/ocp/blob/main/ocpmodels/common/utils.py @@ -709,7 +800,8 @@ def get_max_neighbors_mask(natoms, index, atom_distance, max_num_neighbors_thres mask_num_neighbors.index_fill_(0, index_sort, True) return mask_num_neighbors, num_neighbors_image - + + def radius_graph_pbc( radius: float, max_num_neighbors_threshold: int, @@ -811,9 +903,13 @@ def radius_graph_pbc( # if the required repetitions are very different between images # (which they usually are). Changing this to sparse (scatter) operations # might be worth the effort if this function becomes a bottleneck. 
def calculate_inf_potentials(
    v,
    Omega,
    R=3,
):
    """Evaluate three lattice-summed ("infinite") potentials for a set of vectors.

    The function stacks, row by row, the results of
    ``zeta(param=0.5)`` (labelled Coulomb), ``zeta(param=3.0)`` (labelled LD)
    and ``exp(param=3.0)`` (labelled Pauli), all computed with ``eps=1e-12``.

    Args:
        v (np.ndarray): one vector (1-D) or a batch of vectors (2-D, one per
            row) at which the sums are evaluated.
        Omega (np.ndarray): square lattice matrix; ``Omega.shape[0]`` is used
            as the dimension and its determinant must be non-zero.
        R (int): summation range; lattice images are enumerated with
            ``range(-R, R + 1)`` along every direction, so this must be an
            integer (it is not a distance cutoff).

    Returns:
        np.ndarray: array with three rows (Coulomb, LD, Pauli), one column
        per input vector.

    Raises:
        AssertionError: if ``Omega`` is singular.
    """
    # NOTE(review): the helpers below call `itertools.product`; confirm the
    # module has a plain `import itertools` (not only `from itertools import ...`).
    # The helpers lj, morse, screened_coulomb and screened_coulomb_cal are
    # defined but not used by the final return statement.

    def zeta_cal(v, w, vecs, vecs_inv, d, det, p=1.0, eps=1e-12):
        # Sum over the enumerated lattice vectors: one term built from the
        # direct vectors `vecs` and one from the inverse-lattice vectors
        # `vecs_inv`, both via the incomplete Bessel helper. The result is
        # complex; callers take `.real`.
        result = sum(
            np.e ** (2 * np.pi * 1.0j * vecs @ w)
            * cython_upper_bessel(-p, np.linalg.norm(vecs + v, axis=1) ** 2, 0, eps)
            + np.e ** (2 * np.pi * 1.0j * v @ w)
            / det
            * np.pi ** (d / 2)
            * np.e ** (-2 * np.pi * 1.0j * vecs_inv @ v)
            * cython_upper_bessel(
                p - d / 2,
                np.pi**2 * np.linalg.norm(vecs_inv + w, axis=1) ** 2,
                0,
                eps,
            )
        )

        # `vecs` excludes the origin (see epstein), so the origin contribution
        # is handled here: a constant correction when v is exactly zero,
        # otherwise a single Bessel term at |v|^2.
        if (v == 0).all():
            result = result - 1.0 / p
        else:
            result = result + cython_upper_bessel_k(-p, np.linalg.norm(v) ** 2, 0, eps)

        # Same treatment for the inverse-lattice origin with respect to w.
        if (w == 0).all():
            result = result - np.pi ** (d / 2) / ((d / 2 - p) * det)
        else:
            result = result + np.e ** (2 * np.pi * 1.0j * v @ w) / det * np.pi ** (
                d / 2
            ) * cython_upper_bessel_k(
                p - d / 2, np.pi**2 * np.linalg.norm(w) ** 2, 0, eps
            )
        return result

    def epstein(v, w, Omega, param=1.0, R=3, eps=1e-12, parallel=False, verbose=False):
        # Batch driver for zeta_cal. `parallel` is accepted but never read here.
        d = Omega.shape[0]

        # Promote a single vector pair to a batch of one.
        assert len(np.shape(v)) == len(np.shape(w))
        if len(np.shape(v)) == 1:
            v = [v]
            w = [w]

        v = np.array(v, dtype=np.double)
        w = np.array(w, dtype=np.double)

        num_vectors = v.shape[0]

        # normalization
        # Rescale the lattice to unit volume; v is scaled down and w scaled up
        # by the same factor, and the result is rescaled on return.
        det = abs(np.linalg.det(Omega))
        assert det > 0

        gamma_norm = det ** (1.0 / d)
        Omega = Omega / gamma_norm
        Omega_inv = np.linalg.inv(Omega).T
        v = v / gamma_norm
        w = w * gamma_norm
        det = 1.0

        gamma_p = cython_gsl_sf_gamma(param)

        # All integer index tuples in [-R, R]^d except the all-zero tuple.
        products = np.array(
            [
                l
                for l in itertools.product(*[list(range(-R, R + 1)) for _ in range(d)])
                if any(l)
            ]
        )
        vecs = products @ Omega
        vecs_inv = products @ Omega_inv

        # Optional diagnostic: print an error bound per input vector. This
        # does not affect the returned values.
        if verbose:
            for i in range(num_vectors):
                # Nearest-neighbour index tuples, used to find the shortest
                # non-zero lattice vector (rho1 / rho2 below).
                rounds = np.array(
                    [
                        l
                        for l in itertools.product(
                            *[list(range(-1, 2)) for _ in range(d)]
                        )
                        if any(l)
                    ]
                )
                _, s1, _ = np.linalg.svd(Omega)
                minor_minus1 = (
                    np.clip(s1[-1] * R - np.linalg.norm(v[i]), a_min=0, a_max=np.inf)
                    ** 2
                )
                error_radius1 = np.sqrt(minor_minus1)
                rho1 = np.min(np.linalg.norm(rounds @ Omega, axis=1))
                error1 = (
                    d
                    / 2
                    * (2 / rho1) ** d
                    * cython_gsl_sf_gamma_inc(d / 2, (error_radius1 - rho1 / 2) ** 2)
                )

                _, s2, _ = np.linalg.svd(Omega_inv)
                minor_minus2 = (
                    np.clip(s2[-1] * R - np.linalg.norm(w[i]), a_min=0, a_max=np.inf)
                    ** 2
                )
                error_radius2 = np.sqrt(np.pi**2 * minor_minus2)
                rho2 = np.pi * np.min(np.linalg.norm(rounds @ Omega_inv, axis=1))
                error2 = (
                    d
                    / 2
                    * (2 / rho2) ** d
                    * cython_gsl_sf_gamma_inc(d / 2, (error_radius2 - rho2 / 2) ** 2)
                )
                print(
                    "Error upper bound for "
                    + str(i)
                    + " vector is "
                    + str(error1 + error2)
                )

        values = np.array(
            [
                zeta_cal(v[i], w[i], vecs, vecs_inv, d, det, p=param, eps=eps).real
                for i in range(num_vectors)
            ],
            dtype=np.double,
        )

        # Undo the unit-volume normalization and divide by Gamma(param).
        return values * gamma_norm ** (-2 * param) / gamma_p

    # Coulomb
    # epstein evaluated with w fixed to zero.
    def zeta(v, Omega, param=1.0, R=3, eps=1e-12, parallel=False, verbose=False):
        return epstein(
            v,
            np.zeros_like(v),
            Omega,
            param=param,
            R=R,
            eps=eps,
            parallel=parallel,
            verbose=verbose,
        )

    def exp_cal(v, vecs, vecs_inv, d, det, B, eps=1e-12):
        # Lattice sum for a single vector v; returns the real part only.
        return sum(
            np.e ** (2 * np.pi * 1.0j * vecs_inv @ v)
            / det
            * cython_upper_bessel(
                -0.5 - d / 2, B + np.pi * np.linalg.norm(vecs_inv, axis=1) ** 2, 0, eps
            )
            + cython_upper_bessel(
                0.5, np.pi * np.linalg.norm(vecs + v, axis=1) ** 2, B, eps
            )
        ).real

    # Pauli
    # Batch driver for exp_cal. `parallel` is accepted but never read here.
    def exp(v, Omega, param=1.0, R=3, eps=1e-12, parallel=False, verbose=False):
        d = Omega.shape[0]

        # Promote a single vector to a batch of one.
        if len(np.shape(v)) == 1:
            v = [v]

        v = np.array(v, dtype=np.double)
        num_vectors = v.shape[0]

        det = abs(np.linalg.det(Omega))
        assert det > 0

        # Rescale the lattice to unit volume (as in epstein); `param` is
        # rescaled by sqrt(gamma_norm).
        gamma_norm = det ** (1.0 / d)
        Omega = Omega / gamma_norm
        Omega_inv = np.linalg.inv(Omega).T
        v = v / gamma_norm
        det = 1.0
        param = param * np.sqrt(gamma_norm)

        # Unlike epstein, the all-zero index tuple is kept in this sum.
        products = np.array(
            [l for l in itertools.product(*[list(range(-R, R + 1)) for _ in range(d)])]
        )

        vecs = products @ Omega
        vecs_inv = products @ Omega_inv

        # Optional diagnostic only; does not affect the returned values.
        if verbose:
            for i in range(num_vectors):
                rounds = np.array(
                    [
                        l
                        for l in itertools.product(
                            *[list(range(-1, 2)) for _ in range(d)]
                        )
                        if any(l)
                    ]
                )
                _, s1, _ = np.linalg.svd(Omega)
                minor_minus1 = (
                    np.clip(s1[-1] * R - np.linalg.norm(v[i]), a_min=0, a_max=np.inf)
                    ** 2
                )
                error_radius1 = np.sqrt(np.pi * minor_minus1)
                rho1 = np.sqrt(np.pi) * np.min(np.linalg.norm(rounds @ Omega, axis=1))
                error1 = (
                    d
                    / 2
                    * (2 / rho1) ** d
                    * cython_gsl_sf_gamma_inc(d / 2, (error_radius1 - rho1 / 2) ** 2)
                )

                _, s2, _ = np.linalg.svd(Omega_inv)
                minor_minus2 = np.clip(s2[-1] * R, a_min=0, a_max=np.inf) ** 2
                error_radius2 = np.sqrt(np.pi * minor_minus2)
                rho2 = np.sqrt(np.pi) * np.min(
                    np.linalg.norm(rounds @ Omega_inv, axis=1)
                )
                error2 = (
                    d
                    / 2
                    * (2 / rho2) ** d
                    * cython_gsl_sf_gamma_inc(d / 2, (error_radius2 - rho2 / 2) ** 2)
                )
                print(
                    "Error upper bound for "
                    + str(i)
                    + " vector is "
                    + str(error1 + error2)
                )

        B = param**2 / 4.0 / np.pi

        values = np.array(
            [exp_cal(v[i], vecs, vecs_inv, d, det, B, eps) for i in range(num_vectors)],
            dtype=np.double,
        )

        return values * param / 2.0 / np.pi

    # TODO: Add error bound approximation for LJ potential
    # Difference of two zeta sums (param=6.0 and param=3.0) weighted by
    # param**12 and param**6. Not used by the return statement below.
    def lj(v, Omega, param=1.0, R=3, eps=1e-12, parallel=False, verbose=False):
        if verbose:
            raise NotImplementedError(
                "Error bound for LJ potential is not implemented yet"
            )
        return param**12 * zeta(
            v, Omega, param=6.0, R=R, eps=eps, parallel=parallel
        ) - param**6 * zeta(v, Omega, param=3.0, R=R, eps=eps, parallel=parallel)

    # TODO: Add error bound approximation for morse potential
    # Combination of two exp sums (2*param and param) with prefactors that
    # depend on `re`. Not used by the return statement below.
    def morse(
        v, Omega, param=1.0, re=1.0, R=3, eps=1e-12, parallel=False, verbose=False
    ):
        if verbose:
            raise NotImplementedError(
                "Error bound for morse potential is not implemented yet"
            )
        return np.exp(2.0 * param * re) * exp(
            v, Omega, param=2.0 * param, R=R, eps=eps, parallel=parallel
        ) - 2.0 * np.exp(param * re) * exp(
            v, Omega, param=param, R=R, eps=eps, parallel=parallel
        )

    def screened_coulomb_cal(v, vecs, vecs_inv, d, det, B, eps=1e-12):
        # Lattice sum for a single vector v (real part), followed by the
        # origin contribution, which is handled separately because `vecs`
        # excludes the origin in screened_coulomb.
        result = sum(
            np.e ** (2 * np.pi * 1.0j * vecs_inv @ v)
            * np.pi ** (d / 2)
            / det
            * cython_upper_bessel(
                0.5 - d / 2,
                B + np.pi**2 * np.linalg.norm(vecs_inv, axis=1) ** 2,
                0,
                eps,
            )
            + cython_upper_bessel(-0.5, np.linalg.norm(vecs + v, axis=1) ** 2, B, eps)
        ).real
        if (v == 0).all():
            result = result + B**0.5 * (
                cython_gsl_sf_gamma(-0.5) - cython_gsl_sf_gamma_inc(-0.5, B)
            )
        else:
            result = result + cython_upper_bessel_k(
                -0.5, np.linalg.norm(v) ** 2, B, eps
            )
        return result

    # TODO: Add error bound approximation for screened coulomb potential
    # Batch driver for screened_coulomb_cal. Not used by the return statement
    # below; `parallel` is accepted but never read here.
    def screened_coulomb(
        v, Omega, param=1.0, R=3, eps=1e-12, parallel=False, verbose=False
    ):
        d = Omega.shape[0]

        if len(np.shape(v)) == 1:
            v = [v]

        v = np.array(v, dtype=np.double)
        num_vectors = v.shape[0]

        det = abs(np.linalg.det(Omega))
        assert det > 0

        # Rescale the lattice to unit volume (as in epstein).
        gamma_norm = det ** (1.0 / d)
        Omega = Omega / gamma_norm
        Omega_inv = np.linalg.inv(Omega).T
        v = v / gamma_norm
        det = 1.0

        param = param * np.sqrt(gamma_norm)

        # All integer index tuples in [-R, R]^d except the all-zero tuple.
        products = np.array(
            [
                l
                for l in itertools.product(*[list(range(-R, R + 1)) for _ in range(d)])
                if any(l)
            ]
        )

        vecs = products @ Omega
        vecs_inv = products @ Omega_inv

        B = param**2

        if verbose:
            raise NotImplementedError(
                "Error bound for screened coulomb potential is not implemented yet"
            )

        values = np.array(
            [
                screened_coulomb_cal(v[i], vecs, vecs_inv, d, det, B, eps)
                for i in range(num_vectors)
            ],
            dtype=np.double,
        )

        return values / np.sqrt(np.pi) / np.sqrt(gamma_norm)

    # One row per potential, one column per input vector.
    return np.vstack(
        (
            zeta(v, Omega, param=0.5, R=R, eps=1e-12),  # Coulomb
            zeta(v, Omega, param=3.0, R=R, eps=1e-12),  # LD
            exp(v, Omega, param=3.0, R=R, eps=1e-12),  # Pauli
        )
    )
class RBFExpansion(torch.nn.Module):
    """Expand interatomic distances with radial basis functions.

    Distances are compared against ``bins`` evenly spaced centers between
    ``vmin`` and ``vmax``; the scaled offsets ``gamma * (distance - centers)``
    are then passed through the kernel selected by ``type``.
    """

    def __init__(
        self,
        vmin: float = 0,
        vmax: float = 8,
        bins: int = 40,
        lengthscale: Optional[float] = None,
        type: str = "gaussian",
    ):
        """Register the centers buffer and derive the scaling factor.

        Args:
            vmin: position of the first center.
            vmax: position of the last center.
            bins: number of centers (size of the expanded last dimension).
            lengthscale: kernel length scale. When ``None`` the mean spacing
                between centers is used and ``gamma = 1 / lengthscale``;
                otherwise ``gamma = 1 / lengthscale**2``.
            type: kernel name; see ``forward`` for the accepted values.
        """
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax
        self.bins = bins
        self.register_buffer("centers", torch.linspace(vmin, vmax, bins))
        self.type = type

        if lengthscale is None:
            # SchNet-style: set the length scale relative to the granularity
            # of the expansion. Computed with torch and converted to a plain
            # Python float so that lengthscale/gamma are not NumPy scalars
            # derived from a torch buffer.
            self.lengthscale = torch.diff(self.centers).mean().item()
            self.gamma = 1 / self.lengthscale
        else:
            self.lengthscale = lengthscale
            self.gamma = 1 / (lengthscale**2)

    def forward(self, distance: torch.Tensor) -> torch.Tensor:
        """Apply RBF expansion to interatomic distance tensor.

        Args:
            distance: tensor that broadcasts against the ``(bins,)`` centers
                buffer, e.g. shape ``(num_edges, 1)``.

        Returns:
            Tensor with the broadcast shape of ``distance - centers``.

        Raises:
            ValueError: if ``self.type`` is not a supported kernel name.
        """
        base = self.gamma * (distance - self.centers)
        kind = self.type
        if kind == "gaussian":
            return (-(base**2)).exp()
        if kind == "quadratic":
            return base**2
        if kind == "linear":
            return base
        if kind == "inverse_quadratic":
            return 1.0 / (1.0 + base**2)
        if kind == "multiquadric":
            return (1.0 + base**2).sqrt()
        if kind == "inverse_multiquadric":
            return 1.0 / (1.0 + base**2).sqrt()
        if kind == "spline":
            return base**2 * (base + 1.0).log()
        if kind == "poisson_one":
            return (base - 1.0) * (-base).exp()
        if kind == "poisson_two":
            return (base - 2.0) / 2.0 * base * (-base).exp()
        if kind == "matern32":
            return (1.0 + 3**0.5 * base) * (-(3**0.5) * base).exp()
        if kind == "matern52":
            return (1.0 + 5**0.5 * base + 5 / 3 * base**2) * (
                -(5**0.5) * base
            ).exp()
        # ValueError is still caught by callers that handle Exception.
        raise ValueError("No Implemented Radial Basis Method")
/*
 * Incomplete Bessel-K style function of order nu with arguments x and y,
 * evaluated to tolerance eps.
 *
 * Special cases are answered directly (closed forms through the GSL gamma
 * functions, or 0.0 when the arguments fall in ranges treated as negligible);
 * otherwise the value is approximated by the ratio N[3] / D[3] of two
 * sequences advanced by the same recurrence until successive ratios differ
 * by at most eps, or an iteration cap is reached.
 */
double upper_bessel_k(double nu, double x, double y, double eps) {
    /* x numerically zero: closed form in terms of the gamma functions. */
    if (is_zero(x)) {
        return pow(y, -nu) * (gsl_sf_gamma(nu) - gsl_sf_gamma_inc(nu, y));
    }

    /* n indexes the recurrence step; n_max caps the number of iterations. */
    int n = 2, n_max = 127;
    double err = 1.0, val_new, val_old;
    /* Sliding windows holding the last four numerator / denominator terms. */
    double N[4], D[4];

    /* Early exits returning 0.0 for argument ranges treated as negligible.
       Moderate orders use fixed thresholds; other orders use a bound of
       1e-50. */
    if (-21 <= nu && nu <= 21) {
        if (x > 111) return 0.0;
        if (x < y) if (x * y > 58.0 * 58.0) return 0.0;
    } else {
        const double bound = 1e-50;

        if (nu >= -1) {
            if (x > gsl_sf_lambert_W0(1 / bound)) return 0.0;
        } else {
            /* fak accumulates the product 1 * 2 * ... over t < -nu. */
            double fak = 1.0;
            for (int t = 1; t < -nu; t++)
                fak *= t;

            if (fak * exp(1 - x) * pow(x, -1) < bound) return 0.0;
        }
    }

    /* y numerically zero: closed form via the incomplete gamma function. */
    if (is_zero(y))
        return pow(x, nu) * gsl_sf_gamma_inc(-nu, x);

    /* Order zero with small arguments: sum a series term by term until the
       loop's bound expression drops to eps or below. */
    if (is_zero(nu)) {
        if (pow(x, 2) + pow(y, 2) < pow(0.75, 2)) {
            int k = 0;
            double fak = 1.0; /* running factorial k! */
            double z = 0.0;
            while (exp(-x) * pow(y, k + 1) / (x * (k + 1) * fak) > eps) {
                z += pow(-1, k) * pow(x * y, k) * gsl_sf_gamma_inc(-k, x) / fak;
                k += 1;
                fak *= k;
            }
            return z;
        }
    }

    /* Seed the first four numerator terms. */
    N[0] = 0.0;
    N[1] = 1.0;
    N[2] = 0.5 * (x + nu + 3.0 - y) * N[1];
    N[3] = (x + nu + 5.0 - y) * N[2] + (2.0 * y - nu - 2.0) * N[1];
    N[3] = N[3] / 3.0;

    /* Seed the first four denominator terms. */
    D[0] = exp(x + y);
    D[1] = (x + nu + 1.0 - y) * D[0];
    D[2] = 0.5 * (x + nu + 3.0 - y) * D[1] + 0.5 * (2.0 * y - nu - 1.0) * D[0];
    D[3] = (x + nu + 5.0 - y) * D[2] + (2.0 * y - nu - 2.0) * D[1] - y * D[0];
    D[3] = D[3] / 3.0;

    val_old = N[2] / D[2];
    val_new = N[3] / D[3];

    err = fabs(val_new - val_old);

    /* Advance both sequences until successive ratios agree within eps. */
    while (err > eps) {

        /* Value already below the tolerance: stop refining. */
        if (fabs(val_new) < eps)
            break;

        /* Iteration cap: return the latest value even if not converged. */
        if (n >= n_max) {
            break;
        }
        n++;

        val_old = val_new;

        /* Shift the windows by one step. */
        N[0] = N[1];
        N[1] = N[2];
        N[2] = N[3];

        D[0] = D[1];
        D[1] = D[2];
        D[2] = D[3];

        /* Numerator and denominator follow the same recurrence. */
        N[3] = (x + nu + 1 + 2 * n - y) * N[2] + (2 * y - nu - n) * N[1] - y * N[0];
        N[3] = N[3] / (n + 1);

        D[3] = (x + nu + 1 + 2 * n - y) * D[2] + (2 * y - nu - n) * D[1] - y * D[0];
        D[3] = D[3] / (n + 1);

        val_new = N[3] / D[3];

        /* If the new ratio is NaN, numerically zero or infinite, fall back
           to the previous ratio and stop. */
        if (isnan(val_new)) {
            val_new = val_old;
            break;
        }

        if (is_zero(val_new)) {
            val_new = val_old;
            break;
        }

        if (isinf(val_new)) {
            val_new = val_old;
            break;
        }

        err = fabs(val_new - val_old);
    }

    return val_new;
}

/*************************************************/
/* wrappers for Bessel K and low. */
/*************************************************/

/* Lower variant, expressed through the upper one by negating the order and
   swapping the two arguments. */
double lower_bessel_k(double nu, double x, double y, double eps) {
    return upper_bessel_k(-nu, y, x, eps);
}
numpy.clip(x, a_min=0, a_max=5e2) + for i in range(num_vectors): + results[i] = gsl_sf_gamma_inc(a, x[i]) + return results + +@cython.boundscheck(False) +@cython.wraparound(False) +def cython_upper_bessel_k(nu, x, y, eps): + return upper_bessel_k(nu, x, y, eps) + +@cython.boundscheck(False) +@cython.wraparound(False) +def cython_lower_bessel_k(nu, x, y, eps): + return lower_bessel_k(nu, x, y, eps) + +@cython.boundscheck(False) +@cython.wraparound(False) +def cython_gsl_sf_gamma(p): + return gsl_sf_gamma(p) + +@cython.boundscheck(False) +@cython.wraparound(False) +def cython_gsl_sf_gamma_inc(p, x): + return gsl_sf_gamma_inc(p, x) + +@cython.boundscheck(False) +@cython.wraparound(False) +def cython_upper_bessel(nu, x, y, eps): + cdef int num_vectors + num_vectors = len(x) + results = numpy.zeros(num_vectors, dtype=numpy.double) + for i in range(num_vectors): + results[i] = upper_bessel_k(nu, x[i], y, eps) + return results diff --git a/matdeeplearn/preprocessor/inf_functions/setup.py b/matdeeplearn/preprocessor/inf_functions/setup.py new file mode 100644 index 00000000..955c5f54 --- /dev/null +++ b/matdeeplearn/preprocessor/inf_functions/setup.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +"""Setup script for the inf_functions package +""" + +import numpy +from Cython.Build import cythonize +from Cython.Distutils import build_ext +from setuptools import Extension, find_packages, setup + +extensions = [ + Extension( + "series", + sources=["series.pyx", "bessel.c"], + include_dirs=[numpy.get_include(), "gsl/include"], + library_dirs=["gsl/lib"], + libraries=["gsl", "gslcblas"], + extra_compile_args=["-I./gsl/include"], + extra_link_args=["-L./gsl/lib"], + ) +] + +setup( + name="inf_functions", + author="kruskallin", + author_email="kruskallin@tamu.edu", + packages=find_packages(), + cmdclass={"build_ext": build_ext}, + ext_modules=cythonize(extensions), + install_requires=[ + "numpy >= 1.13", + ], + zip_safe=False, +) diff --git a/matdeeplearn/preprocessor/processor.py 
b/matdeeplearn/preprocessor/processor.py index 4e0c0712..f3a9d607 100644 --- a/matdeeplearn/preprocessor/processor.py +++ b/matdeeplearn/preprocessor/processor.py @@ -16,6 +16,7 @@ from matdeeplearn.preprocessor.helpers import ( clean_up, generate_edge_features, + generate_edge_features_inf, generate_node_features, get_cutoff_distance_matrix, calculate_edges_master, @@ -472,8 +473,13 @@ def get_data_list(self, dict_structures): # data.edge_descriptor["mask"] = cd_matrix_masked data.edge_descriptor["distance"] = edge_weights # data.distances = edge_weights - + # Infinite potentials + if self.edge_calc_method == "inf": + data.inf_edge_attr = edge_gen_out["inf_edge_attr"] + data.inf_edge_index = edge_gen_out["inf_edge_index"] + + # add additional attributes if self.additional_attributes: for attr in self.additional_attributes: @@ -485,7 +491,11 @@ def get_data_list(self, dict_structures): if self.preprocess_edge_features == True: logging.info("Generating edge features...") - generate_edge_features(data_list, self.edge_dim, self.r, device=self.device) + + if self.edge_calc_method == "inf": + generate_edge_features_inf(data_list, self.edge_dim, self.r, device=self.device) + else: + generate_edge_features(data_list, self.edge_dim, self.r, device=self.device) # compile non-otf transforms logging.debug("Applying transforms.")