diff --git a/src/seap/prediction/datasets.py b/src/seap/prediction/datasets.py index ac27f3a..1cfdd1d 100644 --- a/src/seap/prediction/datasets.py +++ b/src/seap/prediction/datasets.py @@ -7,11 +7,7 @@ # The data is generated using a random set of parameters, and the data is normalized. # The data is then saved to a file. -import os -import sys import random -import itertools - import numpy as np from scipy.special import factorial @@ -55,6 +51,11 @@ def __init__(self, l_max, rn_max, n_div, size=1.0): self.n_div = n_div self.rn_max = rn_max self.lm_set = sph.quantum_number(self.n) + # Precompute factorial matrix for radial normalization (depends only on rn_max) + ns = np.arange(self.rn_max) + n1n2 = ns[:, np.newaxis] + ns[np.newaxis, :] + self._fact_mat = factorial(n1n2 + 2) + self._n1n2 = n1n2 # Prepare cubic grid points and extract points in the inscribed sphere. grid_cart = grids.cartesian_grid(np.repeat(self.n_div, 3), size) grid_pol_in_ball, self.map_idx = grids.extract_inscribed_ball(grid_cart, size) @@ -75,17 +76,17 @@ def set_arrays(self, grid_in_pol): tuple Tuple containing arrays for radial vectors, radial powers, and spherical harmonics. """ - rpow = [] - rvec = [] - sphmat = [] - for r_pol in grid_in_pol: - r, theta, phi = r_pol - rvec.append(r) - rpow.append([r**n for n in range(self.rn_max)]) - sphmat.append([ - sph.spherical_harmonics(l, m, theta, phi) for l, m in self.lm_set - ]) - return np.array(rvec), np.array(rpow).T, np.array(sphmat).T + grid_in_pol = np.asarray(grid_in_pol) + rvec = grid_in_pol[:, 0] + theta = grid_in_pol[:, 1] + phi = grid_in_pol[:, 2] + # rpow: shape (rn_max, n_points) + rpow = rvec[np.newaxis, :] ** np.arange(self.rn_max)[:, np.newaxis] + # sphmat: shape (n_lm, n_points) — scipy sph_harm supports array theta/phi + sphmat = np.array([ + sph.spherical_harmonics(l, m, theta, phi) for l, m in self.lm_set + ]) + return rvec, rpow, sphmat def generate_learning_data(self, n_samples): """ @@ -101,20 +102,21 @@ def generate_learning_data(self, n_samples): tuple Tuple containing input data (x) and target data (y). """ - x = np.empty((n_samples, self.n_div**3), dtype=np.float32) - y = np.empty((n_samples, self.n**2 + self.rn_max + 1), dtype=np.float32) + # Phase 1: Generate all random parameters (cheap per-sample loop) + gammas = np.empty(n_samples) + ans = np.empty((n_samples, self.rn_max)) + c_all = np.empty((n_samples, self.n**2), dtype=np.float32) for i in range(n_samples): - random.seed() - # Generate random parameters for the radial part gamma, an = self.get_random_radial_params() - # Select random spherical harmonics coefficients - n_c = random.choice(range(1, len(self.lm_set))) + gammas[i] = gamma + ans[i] = an + n_c = random.choice(range(1, len(self.lm_set))) lm_selected = random.sample(self.lm_set, k=n_c) - c = self.assign_sph_coeffs(lm_selected).reshape(1, -1) - # Generate data based on parameters - data = self.params_to_boxdata(c, gamma, an) - x[i] = np.ravel(data) - y[i] = np.concatenate([np.ravel(c), [gamma], an]) + c_all[i] = self.assign_sph_coeffs(lm_selected) + + # Phase 2: Batch compute grid data (expensive, now vectorized) + x = self.params_to_boxdata(c_all, gammas, ans).astype(np.float32) + y = np.column_stack([c_all, gammas, ans]).astype(np.float32) return x, y def get_random_radial_params(self): @@ -127,11 +129,10 @@ def get_random_radial_params(self): Tuple containing gamma and normalized radial coefficients (an). """ gamma = random.uniform(1.0, 5.0) - an = [random.uniform(-1.0, 1.0) for k in range(self.rn_max)] - # Normalize the radial coefficients - norm = 0.0 - for n1, n2 in itertools.product(range(self.rn_max), repeat=2): - norm += (an[n1] * an[n2] * factorial(n1 + n2 + 2)) / (2 * gamma)**(n1 + n2 + 3) + an = np.array([random.uniform(-1.0, 1.0) for k in range(self.rn_max)]) + # Vectorized norm using precomputed factorial matrix + denom_mat = (2 * gamma) ** (self._n1n2 + 3) + norm = np.sum(np.outer(an, an) * self._fact_mat / denom_mat) an /= np.sqrt(norm) return gamma, an @@ -181,6 +182,5 @@ def params_to_boxdata(self, c, gamma, an): values = angular_parts * radial_parts # Assign values to grid points data = np.zeros((n_batch, self.n_div**3)) - for i in range(n_batch): - np.put(data[i], self.map_idx, values[i]) + data[:, self.map_idx] = values return data diff --git a/src/seap/prediction/grids.py b/src/seap/prediction/grids.py index 3a54f4b..48e62e5 100644 --- a/src/seap/prediction/grids.py +++ b/src/seap/prediction/grids.py @@ -1,6 +1,4 @@ -import math import numpy as np -from itertools import product # This module provides functions for generating grids of points in Cartesian and polar coordinates. # The functions are used for scatter plotting and for generating data for machine learning models. @@ -62,8 +60,9 @@ def cartesian_grid(n_div_list, size): steps = [float(size / n_div) for n_div in n_div_list] # Generate grid points for each axis xyz = [np.arange(-size/2.0 + step/2.0, size/2.0, step) for step in steps] - # Create a Cartesian product of the grid points - mesh = np.array(list(product(xyz[0], xyz[1], xyz[2]))) + # Create a Cartesian product of the grid points using meshgrid + grids = np.meshgrid(xyz[0], xyz[1], xyz[2], indexing='ij') + mesh = np.stack(grids, axis=-1).reshape(-1, 3) return mesh def extract_inscribed_ball(pos_cart, size): @@ -88,18 +87,12 @@ def extract_inscribed_ball(pos_cart, size): (array([[1.73205081, 0.95531662, 0.78539816]]), [1]) """ rmax = size / 2.0 - pos_cart_in_ball = [] - map_idx = [] - # Iterate over each position and check if it lies within the sphere - for ir, pos in enumerate(pos_cart): - if np.all(pos == 0.0): - continue - norm = np.linalg.norm(pos) - if norm <= rmax: - pos_cart_in_ball.append(pos) - map_idx.append(ir) + norms = np.linalg.norm(pos_cart, axis=1) + mask = (norms > 0) & (norms <= rmax) + map_idx = np.where(mask)[0].tolist() + pos_cart_in_ball = pos_cart[mask] # Convert Cartesian coordinates to polar coordinates - pos_pol_in_ball = cartesian2polar(np.array(pos_cart_in_ball)) + pos_pol_in_ball = cartesian2polar(pos_cart_in_ball) return pos_pol_in_ball, map_idx def cartesian2polar(pos_cart): @@ -131,8 +124,7 @@ def cartesian2polar(pos_cart): # Calculate azimuthal angle phi_tmp = np.arctan2(pos_cart[:, 1], pos_cart[:, 0]) # Convert azimuthal angle range from [-pi, pi] to [0, 2pi] - for i, phi in enumerate(phi_tmp): - pos_pol[i, 2] = math.fmod(phi + 2 * np.pi, 2 * np.pi) + pos_pol[:, 2] = np.fmod(phi_tmp + 2 * np.pi, 2 * np.pi) # Handle NaN values in polar angle nan_index = np.isnan(pos_pol[:, 1]) pos_pol[nan_index, 1] = 0