Conversation
|
@ASKabalan just a few small comments on good practices:
|
| import healpy as hp | ||
| import jax | ||
| import jax.numpy as jnp | ||
| import jax_healpy as jhp |
There was a problem hiding this comment.
Please add the dependency to the pyrpoject.toml
There was a problem hiding this comment.
-
I would suggest changing the approach to bin directly the particles on the spherical maps
-
Please remove from this PR things that are not related to spherical lensing, you can put them into a separate PR
-
We need to add validation of the spherical lensing maps from somewhere, maybe comparing to glass?
|
Ignore last message ^^ Claude was a bit too enthusiastic |
Branch History
You are right .. the other functionalities come from other branch since I created this one out of the fix-sharding branch
Yeah I agree .. I think I found a way to do it better in the litterature
Yes I figured out that my D_R goes outwards which means that the first projection includes only the corner of the box since it is longer that the box depth Line 33 in 09d8228 This explains that wierd circles .. I just need to integrate the D_R backwards
Can I "slightly" rewrite history of this branch then 🙏 |
|
@EiffL Regarding the cache You can easily reproduce the issue with the Spherical RayTracing notebook If the cache is written into in 2 different places in the code and they cannot be in the same HLO function (as in one is inside a lax.while and the other one is at the start of the function) Typically, we want to compute growth factors and radial distances in different places for example to do spherical projection but also to know at which scale factor we want to integrate .. this causes the leak. My solution was caching inside the HLO itself using jax.ensure_compile_time_eval however this needed a little bit more thinking since the scale factor that we use is a tracer and this needs to be used only on concrete values. I still think that this is the best solution. Another solution is io.callback as jakevdp said here But I am not a fan of callbacks (since they are executed on CPU) I will probably look more into it .. however I will keep the desactivation of the cache for now in this PR so I can focus on the spherical stuff |
|
Also one more thing. I think that I will just index the particles from the sharded 3D density and let JAX take the wheel |
|
No need to rewrite history, it's fine. you can just open a new branch to keep your edits on these files. And we can reset the files in this barnch to the version on main with checkout -- |
|
let's discuss the caching in a different thread, I can tell you what my logic was. |
|
and regarding boundary conditions, no it's trivial, you just do a modulo of the box volume on the particle positions before you paint them. But also we shouldn't get to the boundaries of the volume if we do things correctly. Our survey should fit well within the volume |
|
Ok, so here is a simple spherical binning code: from functools import partial
import jax
import jax.numpy as jnp
import jax_healpy as jhp
@partial(jax.jit, static_argnames=('nside',))
def paint_particles_spherical(positions, nside, observer_position,
R_min, R_max, box_size, mesh_shape, weights=None):
"""
Directly bin particles onto HEALPix spherical maps without intermediate
3D Cartesian mesh. This avoids double binning artifacts.
Parameters
----------
positions : ndarray, shape (..., 3)
Particle positions in simulation coordinates
nside : int
HEALPix nside parameter
observer_position : ndarray, shape (3,)
Observer position in comoving coordinates
R_min, R_max : float
Minimum and maximum comoving distance range to include
box_size : float
Size of the simulation box in physical units
mesh_shape : tuple
Shape of the simulation mesh (nx, ny, nz)
weights : ndarray, optional
Particle weights (default: uniform weights)
Returns
-------
healpix_map : ndarray
HEALPix density map
"""
if weights is None:
weights = jnp.ones(positions.shape[:-1])
# Convert particle positions from simulation coordinates to physical coordinates
# by scaling with box_size and mesh_shape
positions = positions * jnp.array(box_size) / jnp.array(mesh_shape)
# Compute relative positions from observer
rel_positions = positions - jnp.asarray(observer_position)
# Convert to spherical coordinates
x, y, z = rel_positions[..., 0], rel_positions[..., 1], rel_positions[..., 2]
# Comoving distance from observer
r = jnp.sqrt(x**2 + y**2 + z**2)
# Apply distance cuts
distance_mask = (r >= R_min) & (r <= R_max)
# Compute angular coordinates (theta, phi in spherical coordinates)
# theta = polar angle from z-axis, phi = azimuthal angle
theta = jnp.arccos(jnp.clip(z / (r + 1e-10), -1, 1))
phi = jnp.arctan2(y, x)
# Convert to HEALPix pixel indices
pixels = jhp.ang2pix(nside, theta.flatten(), phi.flatten())
# Apply distance mask to weights
masked_weights = (weights * distance_mask).flatten()
# Bin particles into HEALPix pixels
npix = jhp.nside2npix(nside)
healpix_map = jnp.bincount(pixels, weights=masked_weights, length=npix)
return healpix_mapThis will generate these nice density maps: Full code here: https://gist.github.com/EiffL/257993c1fa18ae47763b0577d8e2d294 |
|
I'll stop playing with this for now ^^ I let you integrate it with lensing and adding validation tests against glass to make sure we are getting consistent Born raytracing with glass at least |
|
Note that I worked out some points of comparison against other lensing codes here: https://github.com/EiffL/RayTracingTests/blob/main/results.ipynb GLASS in particular, but also Dorian which computes the full amplification matrix and deflection field, in addition to Born approximation. |
|
I am going to take some time to understand this code A in here has the shear components as y2 = -A12 and y1 = (A11 + A22)/ 2 I don't know what parallel transport is so I am going to read the paper I suggest putting togher a plan I saw that you made JAXRT which I assume does ray tracing . and I suggest moving all lensing to jaxRT jaxDecomp for FFTs , jaxpmesh for painting and computing spectrum, cl etc .. jaxpm to generate densitymaps + nbody, jaxrt for ray tracing. Some people (for example SBI people) already have density maps and wan't just to create Kappas and shear so they can use jaxrt independantly) Let me know what do you think |
|
Don't worry too much about the full ray tracing for now. I would suggest, just add spherical Born here, it's very simple, not much more code than what we already have. We'll see about a full fledged raytracing library later. Once you have spherical born lensing here, we can focus on what matters more: testing the resolution/volume we need for DES and testing whether we can detect that Born lensing is not going to be enough. And again, do not worry about E/B mode decomposition or anything like that for now. Let's do full sky first and test before doing any more complicated implementation. We want to go fast to a point where we can test whether a simple lensing implementation has issues or not. We do not want to overengineer first. |
|
The test that we will want to be able to do is to generate a simulation, do full raytracing with Dorian, and then try to analyse it with our simple Born raytracing. This will tell us if born approximation matter or not for the type of analysis we want to do, at the resolution we want to do it with |
|
I think this works now check notebook https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/41-spherical-lensing/notebooks/07-SphericalRayTracing.ipynb |
|
I will start doing cl analysis using https://github.com/EiffL/RayTracingTests/blob/d4c8351c16012eb5f4f6011bb71640cfd4759a93/power_spectrum_analysis.py |
|
So, how is it going? |
|
So, how is it going :-) ? |
|
I am almost there I checked with this https://www.astro.ucla.edu/%7Ewright/CosmoCalc.html I will add a plot of CL in the conversation |
|
I am going to make sure that the Flat2D works aswell and then move on to higher box sizes on HPC |
|
This is becoming a dev branch I will split it to multiple branches when I am done |
|
Quick update it seems that the spherical convergence is working correctly Not the case for Flat 2D Maps Flat-Sky Analysis Summary:Field size: 9.6° × 9.6° The agreement is actually falling short if I increase the resolution I will try multi gpu 1024 now |
|
This is great progress @ASKabalan ! Remember to work in incremental atomic PRs, we can move this discussion to the lensing issue, because here you are testing the impact of different resolutions, but that's a different question from testing if your lensing kernel is working. I would recommend making a first PR where you implement the spherical binning and lensing operation, and include a comparison against glass as a unit test. This test would be independent from the resolution of the simulation, the question it answers is whether you get the right convergence for a given set of density maps. |
There was a problem hiding this comment.
Pull Request Overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
… latests tecnhiques.
…spherical-lensing
…ate-shardmap-import
…spherical-lensing
There was a problem hiding this comment.
Pull Request Overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 11 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| pip install -r requirements-test.txt --no-build-isolation | ||
| # Install additional test dependencies | ||
| pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass' | ||
| # Install packages with test dependencies |
There was a problem hiding this comment.
The comment 'Install packages with test dependencies' appears twice (lines 54 and 62) with different actual operations, which is confusing. The second occurrence at line 62 should be removed as it's followed by a simple echo statement, not another pip install command.
| # Install packages with test dependencies |
.github/workflows/tests.yml
Outdated
| pip install -e .[test] | ||
| # Install build dependencies | ||
| pip install setuptools cython mpi4py | ||
| # Install test requirements with no-build-isolation for faster builds | ||
| # Install test requirements with no-build-isolation for PFFT | ||
| pip install -r requirements-test.txt --no-build-isolation | ||
| # Install additional test dependencies | ||
| pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass' |
There was a problem hiding this comment.
The installation order may cause issues: pip install -e .[test] is run before installing requirements-test.txt which contains numpy==2.2.6. This could lead to dependency conflicts if the project dependencies have numpy version constraints. Consider installing requirements-test.txt before running pip install -e .[test] to establish the numpy version first.
.github/workflows/tests.yml
Outdated
| # Install test requirements with no-build-isolation for PFFT | ||
| pip install -r requirements-test.txt --no-build-isolation | ||
| # Install additional test dependencies | ||
| pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass' |
There was a problem hiding this comment.
pytest is being installed redundantly here. It's already listed as a test dependency in pyproject.toml under [project.optional-dependencies].test and would have been installed by the earlier pip install -e .[test] command on line 55.
| pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass' | |
| pip install diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass' |
| safe_rho_mean = jnp.where(rho_mean == 0, eps, rho_mean) | ||
| delta = density_planes / safe_rho_mean - 1 |
There was a problem hiding this comment.
When rho_mean is exactly 0, safe_rho_mean is set to eps, making delta = density_planes/eps - 1. If density_planes contains zeros (which is expected when rho_mean is 0), this results in -1.0. However, if density_planes contains non-zero values while rho_mean is 0, this will create artificially large delta values. Consider using jnp.where(rho_mean == 0, 0.0, density_planes / rho_mean - 1) to handle the zero-mean case more directly.
| safe_rho_mean = jnp.where(rho_mean == 0, eps, rho_mean) | |
| delta = density_planes / safe_rho_mean - 1 | |
| # Use jnp.where to set delta=0 when rho_mean==0, avoiding artificially large values | |
| delta = jnp.where(rho_mean == 0, 0.0, density_planes / rho_mean - 1) |
| cosmo, jc.utils.z2a(z_max)).squeeze() | ||
| cosmo._workspace = {} | ||
|
|
||
| factors = jnp.clip(jnp.array(config['observer_position_in_box']), 0.0, 0.5) |
There was a problem hiding this comment.
The calculation of 'factors' for box_size adjustment is not documented. The logic clips to [0, 0.5], then applies a transformation that's not immediately clear. Add a comment explaining the purpose and formula: e.g., 'Adjust box size based on observer position to ensure adequate simulation volume in all directions'.
| factors = jnp.clip(jnp.array(config['observer_position_in_box']), 0.0, 0.5) | |
| factors = jnp.clip(jnp.array(config['observer_position_in_box']), 0.0, 0.5) | |
| # Adjust box size based on observer position to ensure adequate simulation volume in all directions. | |
| # The formula expands the box so that, for an observer offset from the center, there is enough space | |
| # in both directions (towards and away from the observer) to reach the maximum comoving distance. | |
| # Specifically, for each axis, the box size is scaled by: | |
| # 1.0 + 2.0 * min(factor, 1.0 - factor) | |
| # where 'factor' is the fractional position of the observer in the box (clipped to [0, 0.5]). |
|
|
||
| # Setup Glass cosmology to match jax_cosmo parameters | ||
| h = cosmo.h | ||
| omega_m = cosmo.Omega_c + cosmo.Omega_b |
There was a problem hiding this comment.
Variable omega_m is not used.
| omega_m = cosmo.Omega_c + cosmo.Omega_b |
|
|
||
| from jaxpm.painting import cic_paint_2d | ||
| from jaxpm.distributed import uniform_particles | ||
| from jaxpm.painting import cic_paint, cic_paint_2d, cic_paint_dx |
There was a problem hiding this comment.
Import of 'cic_paint' is not used.
Import of 'cic_paint_dx' is not used.
| from jaxpm.painting import cic_paint, cic_paint_2d, cic_paint_dx | |
| from jaxpm.painting import cic_paint_2d |
| @@ -0,0 +1,141 @@ | |||
| from jaxpm.growth import E, Gf, dGfa, gp | |||
There was a problem hiding this comment.
Import of 'gp' is not used.
| from jaxpm.growth import E, Gf, dGfa, gp | |
| from jaxpm.growth import E, Gf, dGfa |
| @@ -0,0 +1,435 @@ | |||
| import os | |||
There was a problem hiding this comment.
Import of 'os' is not used.
| import os |
| from cosmology.compat.camb import Cosmology | ||
| from diffrax import (ConstantStepSize, ODETerm, SaveAt, SemiImplicitEuler, | ||
| diffeqsolve) | ||
| from numpy.testing import assert_allclose |
There was a problem hiding this comment.
Import of 'assert_allclose' is not used.
| from numpy.testing import assert_allclose |







No description provided.