Skip to content

41 spherical lensing#42

Open
ASKabalan wants to merge 39 commits intomainfrom
41-spherical-lensing
Open

41 spherical lensing#42
ASKabalan wants to merge 39 commits intomainfrom
41-spherical-lensing

Conversation

@ASKabalan
Copy link
Member

No description provided.

@ASKabalan ASKabalan linked an issue Jun 28, 2025 that may be closed by this pull request
@ASKabalan ASKabalan self-assigned this Jun 28, 2025
@EiffL
Copy link
Member

EiffL commented Jun 28, 2025

@ASKabalan just a few small comments on good practices:

  • remember to add a description of what your PR is doing in the initial post. Here for instance, if you are adding some lensing functionality, you can show a plot or some sort of results of validation tests to give the reviewer a sense of how well the added feature is working.

  • it is good practice to keep your PRs atomistic, around well defined scope. Here for instance I see modifications to growth.py, a simplectic integrator, and some lensing utilities. These are pretty unrelated things, it would be better to split into several PRs

import healpy as hp
import jax
import jax.numpy as jnp
import jax_healpy as jhp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the dependency to the pyrpoject.toml

@EiffL
Copy link
Member

EiffL commented Jun 28, 2025

If I understand correctly what you are proposing here, you first paint the particles onto a cartesian mesh, then you reproject that 3d density onto spherical maps.

I don't think that's a good way to do it, because you combine the effect of 2 binning schemes.

I would suggest instead to bin the particles directly onto the spherical maps? that would be a much better approach.

Also, these plots are very sus
image
we can see a very weird pattern on these maps that probably has something to do iwth the 3d box

Copy link
Member

@EiffL EiffL left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • I would suggest changing the approach to bin directly the particles on the spherical maps

  • Please remove from this PR things that are not related to spherical lensing, you can put them into a separate PR

  • We need to add validation of the spherical lensing maps from somewhere, maybe comparing to glass?

@EiffL
Copy link
Member

EiffL commented Jun 29, 2025

Ignore last message ^^ Claude was a bit too enthusiastic

@ASKabalan
Copy link
Member Author

ASKabalan commented Jun 29, 2025

@EiffL

Branch History

remember to add a description of what your PR is doing in the initial post. Here for instance, if you are adding some lensing functionality, you can show a plot or some sort of results of validation tests to give the reviewer a sense of how well the added feature is working.

it is good practice to keep your PRs atomistic, around well defined scope. Here for instance I see modifications to growth.py, a simplectic integrator, and some lensing utilities. These are pretty unrelated things, it would be better to split into several PRs

You are right .. the other functionalities come from other branch since I created this one out of the fix-sharding branch
I can rewrite history and make sure that it matches that of main .. but we all know what do you think of that 🥲

I would suggest instead to bin the particles directly onto the spherical maps? that would be a much better approach.

Yeah I agree .. I think I found a way to do it better in the litterature

we can see a very weird pattern on these maps that probably has something to do iwth the 3d box

Yes I figured out that my D_R goes outwards which means that the first projection includes only the corner of the box since it is longer that the box depth

R_s = jnp.arange(0, d_R, 1.0) + R

This explains that wierd circles .. I just need to integrate the D_R backwards

Please remove from this PR things that are not related to spherical lensing, you can put them into a separate PR

Can I "slightly" rewrite history of this branch then 🙏

@ASKabalan
Copy link
Member Author

ASKabalan commented Jun 29, 2025

@EiffL Regarding the cache

You can easily reproduce the issue with the Spherical RayTracing notebook

If the cache is written into in 2 different places in the code and they cannot be in the same HLO function (as in one is inside a lax.while and the other one is at the start of the function)

Typically, we want to compute growth factors and radial distances in different places for example to do spherical projection but also to know at which scale factor we want to integrate .. this causes the leak.

My solution was caching inside the HLO itself using jax.ensure_compile_time_eval however this needed a little bit more thinking since the scale factor that we use is a tracer and this needs to be used only on concrete values.

I still think that this is the best solution.

Another solution is io.callback as jakevdp said here

But I am not a fan of callbacks (since they are executed on CPU)
And also I am afraid that this is not differentiable

I will probably look more into it .. however I will keep the desactivation of the cache for now in this PR so I can focus on the spherical stuff

@ASKabalan
Copy link
Member Author

Also one more thing.
I decided to 3D paint before binning because the cic painting handles boundary condtions.
Boundary conditions for spherical painting is going to be hard.

I think that I will just index the particles from the sharded 3D density and let JAX take the wheel

@EiffL
Copy link
Member

EiffL commented Jun 29, 2025

No need to rewrite history, it's fine. you can just open a new branch to keep your edits on these files.

And we can reset the files in this barnch to the version on main with checkout --

@EiffL
Copy link
Member

EiffL commented Jun 29, 2025

let's discuss the caching in a different thread, I can tell you what my logic was.

@EiffL
Copy link
Member

EiffL commented Jun 29, 2025

and regarding boundary conditions, no it's trivial, you just do a modulo of the box volume on the particle positions before you paint them.

But also we shouldn't get to the boundaries of the volume if we do things correctly. Our survey should fit well within the volume

@EiffL
Copy link
Member

EiffL commented Jun 29, 2025

Ok, so here is a simple spherical binning code:

from functools import partial

import jax
import jax.numpy as jnp
import jax_healpy as jhp



@partial(jax.jit, static_argnames=('nside',))
def paint_particles_spherical(positions, nside, observer_position,
                             R_min, R_max, box_size, mesh_shape,  weights=None):
    """
    Directly bin particles onto HEALPix spherical maps without intermediate 
    3D Cartesian mesh. This avoids double binning artifacts.
    
    Parameters
    ----------
    positions : ndarray, shape (..., 3)
        Particle positions in simulation coordinates
    nside : int
        HEALPix nside parameter
    observer_position : ndarray, shape (3,)
        Observer position in comoving  coordinates
    R_min, R_max : float
        Minimum and maximum comoving distance range to include
    box_size : float
        Size of the simulation box in physical units
    mesh_shape : tuple
        Shape of the simulation mesh (nx, ny, nz)
    weights : ndarray, optional
        Particle weights (default: uniform weights)
        
    Returns
    -------
    healpix_map : ndarray
        HEALPix density map
    """
    if weights is None:
        weights = jnp.ones(positions.shape[:-1])
    
    # Convert particle positions from simulation coordinates to physical coordinates
    # by scaling with box_size and mesh_shape
    positions = positions * jnp.array(box_size) / jnp.array(mesh_shape)

    # Compute relative positions from observer
    rel_positions = positions - jnp.asarray(observer_position)
    
    # Convert to spherical coordinates
    x, y, z = rel_positions[..., 0], rel_positions[..., 1], rel_positions[..., 2]
    
    # Comoving distance from observer
    r = jnp.sqrt(x**2 + y**2 + z**2)
    
    # Apply distance cuts
    distance_mask = (r >= R_min) & (r <= R_max)
    
    # Compute angular coordinates (theta, phi in spherical coordinates)
    # theta = polar angle from z-axis, phi = azimuthal angle  
    theta = jnp.arccos(jnp.clip(z / (r + 1e-10), -1, 1))
    phi = jnp.arctan2(y, x)
    
    # Convert to HEALPix pixel indices
    pixels = jhp.ang2pix(nside, theta.flatten(), phi.flatten())
    
    # Apply distance mask to weights
    masked_weights = (weights * distance_mask).flatten()
    
    # Bin particles into HEALPix pixels
    npix = jhp.nside2npix(nside)
    healpix_map = jnp.bincount(pixels, weights=masked_weights, length=npix)
    
    return healpix_map

This will generate these nice density maps:
image

Full code here: https://gist.github.com/EiffL/257993c1fa18ae47763b0577d8e2d294

@EiffL
Copy link
Member

EiffL commented Jun 29, 2025

I'll stop playing with this for now ^^ I let you integrate it with lensing and adding validation tests against glass to make sure we are getting consistent Born raytracing with glass at least

@EiffL
Copy link
Member

EiffL commented Jun 30, 2025

Note that I worked out some points of comparison against other lensing codes here: https://github.com/EiffL/RayTracingTests/blob/main/results.ipynb

GLASS in particular, but also Dorian which computes the full amplification matrix and deflection field, in addition to Born approximation.

@ASKabalan
Copy link
Member Author

I am going to take some time to understand this code
But If I understand correctly

https://github.com/EiffL/RayTracingTests/blob/d4c8351c16012eb5f4f6011bb71640cfd4759a93/dorian_raytracing.py#L209

A in here has the shear components as y2 = -A12 and y1 = (A11 + A22)/ 2

I don't know what parallel transport is so I am going to read the paper

I suggest putting togher a plan

I saw that you made JAXRT which I assume does ray tracing .
IMO jaxpm should go as far as density maps

and I suggest moving all lensing to jaxRT
I can still try to code here but let's maybe discuss how would this ecosystem look like

jaxDecomp for FFTs , jaxpmesh for painting and computing spectrum, cl etc .. jaxpm to generate densitymaps + nbody, jaxrt for ray tracing.

Some people (for example SBI people) already have density maps and wan't just to create Kappas and shear so they can use jaxrt independantly)

Let me know what do you think

@EiffL
Copy link
Member

EiffL commented Jun 30, 2025

Don't worry too much about the full ray tracing for now.

I would suggest, just add spherical Born here, it's very simple, not much more code than what we already have. We'll see about a full fledged raytracing library later.

Once you have spherical born lensing here, we can focus on what matters more: testing the resolution/volume we need for DES and testing whether we can detect that Born lensing is not going to be enough.

And again, do not worry about E/B mode decomposition or anything like that for now. Let's do full sky first and test before doing any more complicated implementation.

We want to go fast to a point where we can test whether a simple lensing implementation has issues or not. We do not want to overengineer first.

@EiffL
Copy link
Member

EiffL commented Jun 30, 2025

The test that we will want to be able to do is to generate a simulation, do full raytracing with Dorian, and then try to analyse it with our simple Born raytracing. This will tell us if born approximation matter or not for the type of analysis we want to do, at the resolution we want to do it with

@ASKabalan
Copy link
Member Author

I think this works now

image

check notebook https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/41-spherical-lensing/notebooks/07-SphericalRayTracing.ipynb

@ASKabalan
Copy link
Member Author

I will start doing cl analysis using https://github.com/EiffL/RayTracingTests/blob/d4c8351c16012eb5f4f6011bb71640cfd4759a93/power_spectrum_analysis.py

@EiffL
Copy link
Member

EiffL commented Jul 1, 2025

So, how is it going?

@ASKabalan
Copy link
Member Author

not going to lie I lost sometime playing with Claude 😄

image

But in short the binning was not a great idea ..
The power is a bit lower than needed
I am going to do a 1024^3 and 512 nside on jean zay with a 90 step simulation

@EiffL
Copy link
Member

EiffL commented Jul 3, 2025

So, how is it going :-) ?

@ASKabalan
Copy link
Member Author

I am almost there
Was struggeling with units of jc.background.radial_comoving_distance
it seems to be Mpc/h
But sometimes people divide by cosmo.h to get it in Mpc
for example https://github.com/LSSTDESC/glass-jax/blob/914451146adaba590eacc23422ce443ebe72a5a3/glass/_src/shells.py#L100

I checked with this https://www.astro.ucla.edu/%7Ewright/CosmoCalc.html
jc.background.radial_comoving_distance(cosmo, a) / cosmo.h is in Mpc

I will add a plot of CL in the conversation

@ASKabalan
Copy link
Member Author

image

I think that this is ok

I still have a factor 2 at low ell for some reason.

The pixel window correction destroys small scale but I think that this is what it is supposed to do
But does not improve large scales

@ASKabalan
Copy link
Member Author

https://github.com/DifferentiableUniverseInitiative/JaxPM/blob/41-spherical-lensing/notebooks/08-Kappa-Comparison-vs-Theory.ipynb

I am going to make sure that the Flat2D works aswell and then move on to higher box sizes on HPC

@ASKabalan
Copy link
Member Author

@EiffL

This is becoming a dev branch

I will split it to multiple branches when I am done

@ASKabalan
Copy link
Member Author

Quick update

it seems that the spherical convergence is working correctly

image

Analysis Parameters:
  Source redshift: z = 0.2
  Box size: 1688 Mpc/h
  HEALPix resolution: nside = 128
  Mesh sizes tested: [128, 256, 512]

Key Results:
  128³ mesh: mean agreement = 950.0% of theory
  256³ mesh: mean agreement = 167.5% of theory
  512³ mesh: mean agreement = 120.9% of theory

Not the case for Flat 2D Maps

image

Flat-Sky Analysis Summary:

Field size: 9.6° × 9.6°
Field resolution: 128 × 128 pixels
Angular resolution: 4.5 arcmin/pixel

128³ mesh:
  κ range: [-1.01e-03, 2.77e-03]
  κ RMS: 6.70e-04
  Agreement with theory: 16.9%

256³ mesh:
  κ range: [-9.26e-04, 2.35e-03]
  κ RMS: 6.41e-04
  Agreement with theory: 13.9%

512³ mesh:
  κ range: [-9.47e-04, 2.01e-03]
  κ RMS: 4.57e-04
  Agreement with theory: 8.1%

The agreement is actually falling short if I increase the resolution
Must be something wrong with the density creation function.

I will try multi gpu 1024 now

@EiffL
Copy link
Member

EiffL commented Jul 4, 2025

This is great progress @ASKabalan !

Remember to work in incremental atomic PRs, we can move this discussion to the lensing issue, because here you are testing the impact of different resolutions, but that's a different question from testing if your lensing kernel is working.

I would recommend making a first PR where you implement the spherical binning and lensing operation, and include a comparison against glass as a unit test. This test would be independent from the resolution of the simulation, the question it answers is whether you get the right convergence for a given set of density maps.

@ASKabalan ASKabalan requested a review from Copilot October 7, 2025 22:03
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.


Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

@ASKabalan ASKabalan requested a review from Copilot November 4, 2025 12:56
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 11 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

pip install -r requirements-test.txt --no-build-isolation
# Install additional test dependencies
pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass'
# Install packages with test dependencies
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment 'Install packages with test dependencies' appears twice (lines 54 and 62) with different actual operations, which is confusing. The second occurrence at line 62 should be removed as it's followed by a simple echo statement, not another pip install command.

Suggested change
# Install packages with test dependencies

Copilot uses AI. Check for mistakes.
Comment on lines 55 to 61
pip install -e .[test]
# Install build dependencies
pip install setuptools cython mpi4py
# Install test requirements with no-build-isolation for faster builds
# Install test requirements with no-build-isolation for PFFT
pip install -r requirements-test.txt --no-build-isolation
# Install additional test dependencies
pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass'
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The installation order may cause issues: pip install -e .[test] is run before installing requirements-test.txt which contains numpy==2.2.6. This could lead to dependency conflicts if the project dependencies have numpy version constraints. Consider installing requirements-test.txt before running pip install -e .[test] to establish the numpy version first.

Copilot uses AI. Check for mistakes.
# Install test requirements with no-build-isolation for PFFT
pip install -r requirements-test.txt --no-build-isolation
# Install additional test dependencies
pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass'
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytest is being installed redundantly here. It's already listed as a test dependency in pyproject.toml under [project.optional-dependencies].test and would have been installed by the earlier pip install -e .[test] command on line 55.

Suggested change
pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass'
pip install diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass'

Copilot uses AI. Check for mistakes.
Comment on lines +170 to +171
safe_rho_mean = jnp.where(rho_mean == 0, eps, rho_mean)
delta = density_planes / safe_rho_mean - 1
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When rho_mean is exactly 0, safe_rho_mean is set to eps, making delta = density_planes/eps - 1. If density_planes contains zeros (which is expected when rho_mean is 0), this results in -1.0. However, if density_planes contains non-zero values while rho_mean is 0, this will create artificially large delta values. Consider using jnp.where(rho_mean == 0, 0.0, density_planes / rho_mean - 1) to handle the zero-mean case more directly.

Suggested change
safe_rho_mean = jnp.where(rho_mean == 0, eps, rho_mean)
delta = density_planes / safe_rho_mean - 1
# Use jnp.where to set delta=0 when rho_mean==0, avoiding artificially large values
delta = jnp.where(rho_mean == 0, 0.0, density_planes / rho_mean - 1)

Copilot uses AI. Check for mistakes.
cosmo, jc.utils.z2a(z_max)).squeeze()
cosmo._workspace = {}

factors = jnp.clip(jnp.array(config['observer_position_in_box']), 0.0, 0.5)
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The calculation of 'factors' for box_size adjustment is not documented. The logic clips to [0, 0.5], then applies a transformation that's not immediately clear. Add a comment explaining the purpose and formula: e.g., 'Adjust box size based on observer position to ensure adequate simulation volume in all directions'.

Suggested change
factors = jnp.clip(jnp.array(config['observer_position_in_box']), 0.0, 0.5)
factors = jnp.clip(jnp.array(config['observer_position_in_box']), 0.0, 0.5)
# Adjust box size based on observer position to ensure adequate simulation volume in all directions.
# The formula expands the box so that, for an observer offset from the center, there is enough space
# in both directions (towards and away from the observer) to reach the maximum comoving distance.
# Specifically, for each axis, the box size is scaled by:
# 1.0 + 2.0 * min(factor, 1.0 - factor)
# where 'factor' is the fractional position of the observer in the box (clipped to [0, 0.5]).

Copilot uses AI. Check for mistakes.

# Setup Glass cosmology to match jax_cosmo parameters
h = cosmo.h
omega_m = cosmo.Omega_c + cosmo.Omega_b
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable omega_m is not used.

Suggested change
omega_m = cosmo.Omega_c + cosmo.Omega_b

Copilot uses AI. Check for mistakes.

from jaxpm.painting import cic_paint_2d
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_2d, cic_paint_dx
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'cic_paint' is not used.
Import of 'cic_paint_dx' is not used.

Suggested change
from jaxpm.painting import cic_paint, cic_paint_2d, cic_paint_dx
from jaxpm.painting import cic_paint_2d

Copilot uses AI. Check for mistakes.
@@ -0,0 +1,141 @@
from jaxpm.growth import E, Gf, dGfa, gp
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'gp' is not used.

Suggested change
from jaxpm.growth import E, Gf, dGfa, gp
from jaxpm.growth import E, Gf, dGfa

Copilot uses AI. Check for mistakes.
@@ -0,0 +1,435 @@
import os
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'os' is not used.

Suggested change
import os

Copilot uses AI. Check for mistakes.
from cosmology.compat.camb import Cosmology
from diffrax import (ConstantStepSize, ODETerm, SaveAt, SemiImplicitEuler,
diffeqsolve)
from numpy.testing import assert_allclose
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'assert_allclose' is not used.

Suggested change
from numpy.testing import assert_allclose

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

spherical-lensing

3 participants