Skip to content

Commit 8fd9319

Browse files
author
jngaravitoc
committed
re-strucuring basis
1 parent 7827017 commit 8fd9319

File tree

4 files changed

+471
-0
lines changed

4 files changed

+471
-0
lines changed

EXPtools/basis/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .profiles import Profiles
2+
from .basis_utils import *
3+
from .makemodel import *

EXPtools/basis/basis_utils.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import os
2+
import re
3+
import yaml
4+
import numpy as np
5+
import pyEXP
6+
from EXPtools.basis.makemodel import make_model
7+
8+
def load_basis(config_name, cache_dir=None):
9+
"""
10+
Load a basis configuration from a YAML file and initialize a Basis object.
11+
12+
Parameters
13+
----------
14+
config_name : str
15+
Path to the YAML configuration file. If the provided filename does not
16+
end with `.yaml`, the extension is automatically appended.
17+
cache_dir : str (optional)
18+
Path to modelname and cachename (assumes both are in the same directory)
19+
if None assumes it is the current working directory.
20+
21+
Returns
22+
-------
23+
basis : pyEXP.basis.Basis
24+
An initialized Basis object created from the configuration.
25+
26+
Raises
27+
------
28+
FileNotFoundError
29+
If the specified YAML file does not exist.
30+
"""
31+
32+
# Check file existence
33+
if not os.path.exists(config_name):
34+
raise FileNotFoundError(f"Configuration file not found: {config_name}")
35+
36+
# Load YAML safely
37+
with open(config_name, "r") as f:
38+
config = yaml.load(f, Loader=yaml.FullLoader)
39+
40+
if cache_dir:
41+
config = re.sub(r"(modelname:\s*)(\S+)", rf"\1{cache_dir}\2", config)
42+
config = re.sub(r"(cachename:\s*)(\S+)", rf"\1{cache_dir}\2", config)
43+
# Build basis from configuration
44+
basis = pyEXP.basis.Basis.factory(config)
45+
return basis
46+
47+
def check_basis_params(basis_id, **kwargs):
48+
"""
49+
Check that the required keyword arguments for a given basis are provided.
50+
51+
Parameters
52+
----------
53+
basis_id : str
54+
The identifier of the basis set.
55+
Accepted values are:
56+
- ``'sphereSL'`` : Spherical basis with Stäckel-like mapping.
57+
- ``'cylinder'`` : Cylindrical basis.
58+
**kwargs : dict
59+
Arbitrary keyword arguments corresponding to the basis parameters.
60+
The required keys depend on ``basis_id``:
61+
62+
- For ``'sphereSL'``:
63+
['Lmax', 'mmax', 'modelname', 'rmapping', 'cachename']
64+
65+
- For ``'cylinder'``:
66+
['acyl', 'hcyl', 'nmaxfid', 'lmaxfid', 'mmax', 'nmax', 'ncylodd',
67+
'ncylnx', 'ncylny', 'rnum', 'pmun', 'tnum', 'vflag', 'logr', 'cachename']
68+
69+
Returns
70+
-------
71+
bool
72+
Returns ``True`` if all mandatory parameters are present.
73+
74+
Raises
75+
------
76+
KeyError
77+
If one or more mandatory keyword arguments are missing for the selected basis.
78+
AttributeError
79+
If ``basis_id`` is not recognized (must be either 'sphereSL' or 'cylinder').
80+
81+
Examples
82+
--------
83+
check_basis_params('sphereSL', Lmax=4, nmax=4, modelname='hernquist.txt',
84+
... rmapping=1, cachename='cache_hernquist.txt')
85+
True
86+
87+
check_basis_params('cylinder', acyl=1.0, hcyl=2.0)
88+
Traceback (most recent call last):
89+
...
90+
KeyError: "Missing mandatory keyword arguments missing: [...]"
91+
"""
92+
93+
if basis_id == 'sphereSL':
94+
mandatory_keys = ['Lmax', 'nmax', 'modelname', 'rmapping', 'cachename']
95+
missing = [key for key in mandatory_keys if key not in kwargs]
96+
if missing:
97+
raise KeyError(f"Missing mandatory keyword arguments missing: {missing}")
98+
return True
99+
elif basis_id == 'cylinder':
100+
mandatory_keys = ['acyl', 'hcyl', 'nmaxfid', 'lmaxfid',
101+
'mmax', 'nmax', 'ncylodd', 'ncylnx',
102+
'ncylny', 'rnum', 'pnum', 'tnum', 'vflag', 'logr', 'cachename']
103+
missing = [key for key in mandatory_keys if key not in kwargs]
104+
if missing:
105+
raise KeyError(f"Missing mandatory keyword arguments missing: {missing}")
106+
return True
107+
else:
108+
raise AttributeError(f"basis id {basis_id} not found. Please chose between sphereSL or cylinder")
109+
110+
111+
def write_config(basis_id, float_fmt_rmin="{:.7f}", float_fmt_rmax="{:.3f}",
112+
float_fmt_rmapping="{:.3f}", **params):
113+
"""
114+
Create a YAML configuration file string for building a basis model.
115+
116+
Parameters
117+
----------
118+
basis_id : str
119+
Identifier of the basis model. Must be either 'sphereSL' or 'cylinder'.
120+
float_fmt_rmin : str, optional
121+
Format string for rmin (default ``"{:.7f}"``).
122+
float_fmt_rmax : str, optional
123+
Format string for rmax (default ``"{:.3f}"``).
124+
float_fmt_rmapping : str, optional
125+
Format string for rmapping (default ``"{:.3f}"``).
126+
**params : dict
127+
Additional keyword arguments required depending on the basis type:
128+
129+
- For ``sphereSL``:
130+
['lmax', 'nmax', 'rmapping', 'modelname', 'cachename']
131+
132+
- For ``cylinder``:
133+
['acyl', 'hcyl', 'nmaxfid', 'lmaxfid', 'mmax', 'nmax',
134+
'ncylodd', 'ncylnx', 'ncylny', 'rnum', 'pnum', 'tnum',
135+
'vflag', 'logr', 'cachename']
136+
137+
Returns
138+
-------
139+
str
140+
YAML configuration file contents.
141+
142+
Raises
143+
------
144+
KeyError
145+
If mandatory parameters for the given basis are missing.
146+
FileNotFoundError
147+
If ``modelname`` is required but cannot be opened.
148+
ValueError
149+
If the model file does not contain valid radius data.
150+
"""
151+
152+
check_basis_params(basis_id, **params)
153+
154+
155+
if basis_id == "sphereSL":
156+
modelname = params["modelname"]
157+
try:
158+
R = np.loadtxt(modelname, skiprows=3, usecols=0)
159+
except OSError as e:
160+
raise FileNotFoundError(f"Could not open model file '{modelname}'") from e
161+
if R.size == 0:
162+
raise ValueError(f"Model file '{modelname}' contains no radius data")
163+
164+
rmin, rmax, numr = R[0], R[-1], len(R)
165+
params["rmin"] = rmin
166+
params["rmax"] = rmax
167+
params["numr"] = numr
168+
169+
170+
config_dict = {
171+
"id": basis_id,
172+
"parameters": params
173+
}
174+
175+
176+
return yaml.dump(config_dict, sort_keys=False)
177+
178+
def make_basis(R, D, Mtotal, basis_params, modelname="test_model.txt", cachename='test_cache.txt'):
179+
"""
180+
Construct a basis from a given radial density profile.
181+
182+
Parameters
183+
----------
184+
R : array_like
185+
Radial grid points (e.g., radii at which density `D` is defined).
186+
D : array_like
187+
Density values corresponding to each radius in `R`.
188+
Mtotal : float, optional
189+
Total mass normalization (default is 1.0).
190+
basis_params : dict
191+
basis parameters e.g., basis_id, nmax, lmax
192+
193+
Returns
194+
-------
195+
basis : pyEXP.basis.Basis
196+
A basis object initialized with the given density model.
197+
198+
Notes
199+
-----
200+
- This function wraps `makemodel.makemodel` to generate a model from
201+
the supplied density profile and total mass.
202+
- It then builds a basis either spherical (`sphereSL`) or cylindrical using `EXPtools.make_config`
203+
and returns the corresponding `pyEXP` basis object.
204+
"""
205+
206+
R, D, _, _ = make_model(
207+
R, D, Mtotal=Mtotal,
208+
output_filename=modelname
209+
)
210+
211+
config = write_config(
212+
basis_id=basis_params['basis_id'],
213+
modelname=modelname,
214+
cachename=cachename
215+
)
216+
217+
basis = pyEXP.basis.Basis.factory(config)
218+
return basis

EXPtools/basis/makemodel.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import numpy as np
2+
3+
def write_table(tablename, radius, density, mass, potential, fmt="%.6e"):
4+
"""
5+
Write a table of radius, density, mass, and potential values to a text file.
6+
7+
Parameters
8+
----------
9+
tablename : str
10+
Output filename.
11+
radius, density, mass, potential : array-like
12+
Arrays of physical quantities, all with the same length.
13+
fmt : str, optional
14+
Format string for numerical values. Defaults to scientific notation with 6 decimals.
15+
16+
Notes
17+
-----
18+
Writes the table in the following format:
19+
! <tablename>
20+
! R D M P
21+
<Nrows>
22+
<radius> <density> <mass> <potential>
23+
"""
24+
# Convert inputs to NumPy arrays (for safety and performance)
25+
radius = np.asarray(radius)
26+
density = np.asarray(density)
27+
mass = np.asarray(mass)
28+
potential = np.asarray(potential)
29+
30+
# Stack data into a single 2D array for fast writing
31+
data = np.column_stack((radius, density, mass, potential))
32+
33+
header = f"! {tablename}\n! R D M P\n{len(radius)}"
34+
np.savetxt(tablename, data, fmt=fmt, header=header, comments="")
35+
36+
def make_model(radius, density, Mtotal, output_filename='', physical_units=False, verbose=True):
37+
"""
38+
Generate an EXP-compatible spherical basis function table.
39+
40+
Parameters
41+
----------
42+
radius : array-like
43+
Radii at which the density values are evaluated.
44+
density : array-like
45+
Density values corresponding to radius.
46+
Mtotal : float
47+
Total mass of the model, used for normalization.
48+
output_filename : str, optional
49+
Name of the output file to save the table. If empty, no file is written.
50+
physical_units : bool, optional
51+
If True, disables scaling and returns physical values (default: False).
52+
verbose : bool, optional
53+
If True, prints scaling information.
54+
55+
Returns
56+
-------
57+
result : dict
58+
Dictionary with the following keys:
59+
- 'radius' : ndarray
60+
Scaled radius values.
61+
- 'density' : ndarray
62+
Scaled density values.
63+
- 'mass' : ndarray
64+
Scaled enclosed mass values.
65+
- 'potential' : ndarray
66+
Scaled potential values.
67+
68+
Modified from Mike Petersen's original code:
69+
https://gist.github.com/michael-petersen/ec4f20641eedac8f63ec409c9cc65ed7
70+
"""
71+
72+
EPS_MASS = 1e-15
73+
EPS_R = 1e-10
74+
75+
Rmax = np.nanmax(radius)
76+
77+
mass = np.zeros_like(density)
78+
pwvals = np.zeros_like(density)
79+
80+
mass[0] = 1.e-15
81+
pwvals[0] = 0.
82+
83+
dr = np.diff(radius)
84+
85+
# Midpoint integration for enclosed mass and potential
86+
mass_contrib = 2.0 * np.pi * (
87+
radius[:-1]**2 * density[:-1] + radius[1:]**2 * density[1:]
88+
) * dr
89+
90+
pwvals_contrib = 2.0 * np.pi * (
91+
radius[:-1] * density[:-1] + radius[1:] * density[1:]
92+
) * dr
93+
94+
# Now cumulative sum to get the arrays
95+
mass = np.concatenate(([EPS_MASS], EPS_MASS + np.cumsum(mass_contrib)))
96+
pwvals = np.concatenate(([0.0], np.cumsum(pwvals_contrib)))
97+
98+
potential = -mass / (radius + EPS_R) - (pwvals[-1] - pwvals)
99+
100+
M0 = mass[-1]
101+
R0 = radius[-1]
102+
103+
Beta = (Mtotal / M0) * (R0 / Rmax)
104+
Gamma = np.sqrt((M0 * R0) / (Mtotal * Rmax)) * (R0 / Rmax)
105+
106+
if verbose:
107+
print(f"! Scaling: R = {Rmax} M = {Mtotal}")
108+
109+
rfac = Beta**-0.25 * Gamma**-0.5
110+
dfac = Beta**1.5 * Gamma
111+
mfac = Beta**0.75 * Gamma**-0.5
112+
pfac = Beta
113+
114+
if physical_units:
115+
rfac = dfac = mfac = pfac = 1.0
116+
117+
if verbose:
118+
print(f"Scaling factors: rfac = {rfac}, dfac = {dfac}, mfac = {mfac}, pfac = {pfac}")
119+
120+
if output_filename:
121+
write_table(
122+
output_filename,
123+
radius * rfac,
124+
density * dfac,
125+
mass * mfac,
126+
potential * pfac
127+
)
128+
129+
return {
130+
"radius": radius * rfac,
131+
"density": density * dfac,
132+
"mass": mass * mfac,
133+
"potential": potential * pfac,
134+
}
135+

0 commit comments

Comments
 (0)