Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions src/seap/prediction/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
# The data is generated using a random set of parameters, and the data is normalized.
# The data is then saved to a file.

import os
import sys
import random
import itertools

import numpy as np
from scipy.special import factorial

Expand Down Expand Up @@ -55,6 +51,11 @@ def __init__(self, l_max, rn_max, n_div, size=1.0):
self.n_div = n_div
self.rn_max = rn_max
self.lm_set = sph.quantum_number(self.n)
# Precompute factorial matrix for radial normalization (depends only on rn_max)
ns = np.arange(self.rn_max)
n1n2 = ns[:, np.newaxis] + ns[np.newaxis, :]
self._fact_mat = factorial(n1n2 + 2)
self._n1n2 = n1n2
# Prepare cubic grid points and extract points in the inscribed sphere.
grid_cart = grids.cartesian_grid(np.repeat(self.n_div, 3), size)
grid_pol_in_ball, self.map_idx = grids.extract_inscribed_ball(grid_cart, size)
Expand All @@ -75,17 +76,17 @@ def set_arrays(self, grid_in_pol):
tuple
Tuple containing arrays for radial vectors, radial powers, and spherical harmonics.
"""
rpow = []
rvec = []
sphmat = []
for r_pol in grid_in_pol:
r, theta, phi = r_pol
rvec.append(r)
rpow.append([r**n for n in range(self.rn_max)])
sphmat.append([
sph.spherical_harmonics(l, m, theta, phi) for l, m in self.lm_set
])
return np.array(rvec), np.array(rpow).T, np.array(sphmat).T
grid_in_pol = np.asarray(grid_in_pol)
rvec = grid_in_pol[:, 0]
theta = grid_in_pol[:, 1]
phi = grid_in_pol[:, 2]
# rpow: shape (rn_max, n_points)
rpow = rvec[np.newaxis, :] ** np.arange(self.rn_max)[:, np.newaxis]
# sphmat: shape (n_lm, n_points) — scipy sph_harm supports array theta/phi
sphmat = np.array([
sph.spherical_harmonics(l, m, theta, phi) for l, m in self.lm_set
])
return rvec, rpow, sphmat

def generate_learning_data(self, n_samples):
"""
Expand All @@ -101,20 +102,21 @@ def generate_learning_data(self, n_samples):
tuple
Tuple containing input data (x) and target data (y).
"""
x = np.empty((n_samples, self.n_div**3), dtype=np.float32)
y = np.empty((n_samples, self.n**2 + self.rn_max + 1), dtype=np.float32)
# Phase 1: Generate all random parameters (cheap per-sample loop)
gammas = np.empty(n_samples)
ans = np.empty((n_samples, self.rn_max))
c_all = np.empty((n_samples, self.n**2), dtype=np.float32)
for i in range(n_samples):
random.seed()
# Generate random parameters for the radial part
gamma, an = self.get_random_radial_params()
# Select random spherical harmonics coefficients
n_c = random.choice(range(1, len(self.lm_set)))
gammas[i] = gamma
ans[i] = an
n_c = random.choice(range(1, len(self.lm_set)))
lm_selected = random.sample(self.lm_set, k=n_c)
c = self.assign_sph_coeffs(lm_selected).reshape(1, -1)
# Generate data based on parameters
data = self.params_to_boxdata(c, gamma, an)
x[i] = np.ravel(data)
y[i] = np.concatenate([np.ravel(c), [gamma], an])
c_all[i] = self.assign_sph_coeffs(lm_selected)

# Phase 2: Batch compute grid data (expensive, now vectorized)
x = self.params_to_boxdata(c_all, gammas, ans).astype(np.float32)
y = np.column_stack([c_all, gammas, ans]).astype(np.float32)
return x, y

def get_random_radial_params(self):
Expand All @@ -127,11 +129,10 @@ def get_random_radial_params(self):
Tuple containing gamma and normalized radial coefficients (an).
"""
gamma = random.uniform(1.0, 5.0)
an = [random.uniform(-1.0, 1.0) for k in range(self.rn_max)]
# Normalize the radial coefficients
norm = 0.0
for n1, n2 in itertools.product(range(self.rn_max), repeat=2):
norm += (an[n1] * an[n2] * factorial(n1 + n2 + 2)) / (2 * gamma)**(n1 + n2 + 3)
an = np.array([random.uniform(-1.0, 1.0) for k in range(self.rn_max)])
# Vectorized norm using precomputed factorial matrix
denom_mat = (2 * gamma) ** (self._n1n2 + 3)
norm = np.sum(np.outer(an, an) * self._fact_mat / denom_mat)
an /= np.sqrt(norm)
return gamma, an

Expand Down Expand Up @@ -181,6 +182,5 @@ def params_to_boxdata(self, c, gamma, an):
values = angular_parts * radial_parts
# Assign values to grid points
data = np.zeros((n_batch, self.n_div**3))
for i in range(n_batch):
np.put(data[i], self.map_idx, values[i])
data[:, self.map_idx] = values
return data
26 changes: 9 additions & 17 deletions src/seap/prediction/grids.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import math
import numpy as np
from itertools import product

# This module provides functions for generating grids of points in Cartesian and polar coordinates.
# The functions are used for scatter plotting and for generating data for machine learning models.
Expand Down Expand Up @@ -62,8 +60,9 @@ def cartesian_grid(n_div_list, size):
steps = [float(size / n_div) for n_div in n_div_list]
# Generate grid points for each axis
xyz = [np.arange(-size/2.0 + step/2.0, size/2.0, step) for step in steps]
# Create a Cartesian product of the grid points
mesh = np.array(list(product(xyz[0], xyz[1], xyz[2])))
# Create a Cartesian product of the grid points using meshgrid
grids = np.meshgrid(xyz[0], xyz[1], xyz[2], indexing='ij')
mesh = np.stack(grids, axis=-1).reshape(-1, 3)
return mesh

def extract_inscribed_ball(pos_cart, size):
Expand All @@ -88,18 +87,12 @@ def extract_inscribed_ball(pos_cart, size):
(array([[1.73205081, 0.95531662, 0.78539816]]), [1])
"""
rmax = size / 2.0
pos_cart_in_ball = []
map_idx = []
# Iterate over each position and check if it lies within the sphere
for ir, pos in enumerate(pos_cart):
if np.all(pos == 0.0):
continue
norm = np.linalg.norm(pos)
if norm <= rmax:
pos_cart_in_ball.append(pos)
map_idx.append(ir)
norms = np.linalg.norm(pos_cart, axis=1)
mask = (norms > 0) & (norms <= rmax)
map_idx = np.where(mask)[0].tolist()
pos_cart_in_ball = pos_cart[mask]
# Convert Cartesian coordinates to polar coordinates
pos_pol_in_ball = cartesian2polar(np.array(pos_cart_in_ball))
pos_pol_in_ball = cartesian2polar(pos_cart_in_ball)
return pos_pol_in_ball, map_idx

def cartesian2polar(pos_cart):
Expand Down Expand Up @@ -131,8 +124,7 @@ def cartesian2polar(pos_cart):
# Calculate azimuthal angle
phi_tmp = np.arctan2(pos_cart[:, 1], pos_cart[:, 0])
# Convert azimuthal angle range from [-pi, pi] to [0, 2pi)
for i, phi in enumerate(phi_tmp):
pos_pol[i, 2] = math.fmod(phi + 2 * np.pi, 2 * np.pi)
pos_pol[:, 2] = np.fmod(phi_tmp + 2 * np.pi, 2 * np.pi)
# Handle NaN values in polar angle
nan_index = np.isnan(pos_pol[:, 1])
pos_pol[nan_index, 1] = 0
Expand Down