diff --git a/glass/fields.py b/glass/fields.py index 9eed1dcc3..59b78f0fe 100644 --- a/glass/fields.py +++ b/glass/fields.py @@ -371,7 +371,7 @@ def _generate_grf( gls: Cls, nside: int, *, - ncorr: int | None = None, + ncorrs: list[int] | None = None, rng: np.random.Generator | None = None, ) -> Generator[NDArray[np.float64]]: """ @@ -416,43 +416,62 @@ def _generate_grf( ngls = len(gls) ngrf = nfields_from_nspectra(ngls) - # number of correlated fields if not specified - if ncorr is None: - ncorr = ngrf - 1 - # number of modes n = max((len(gl) for gl in gls), default=0) if n == 0: msg = "all gls are empty" raise ValueError(msg) - # generates the covariance matrix for the iterative sampler - cov = cls2cov(gls, n, ngrf, ncorr) - # working arrays for the iterative sampling z = np.zeros(n * (n + 1) // 2, dtype=np.complex128) - y = np.zeros((n * (n + 1) // 2, ncorr), dtype=np.complex128) + + blocks = [] + block_ns = [] + for j in range(len(gls)): + block = [gls[i][j : j + 1] for i in range(j, len(gls))] + blocks.append(block) + block_ns.append(len(gls) - j) + + # number of correlated fields if not specified + if ncorrs is None: + ncorrs = [ngrf - 1 for _ in range(len(blocks))] # generate the conditional normal distribution for iterative sampling - conditional_dist = iternorm(ncorr, cov, size=n) + conditional_dists = [] + for block, block_n, block_ncorr in zip(blocks, block_ns, ncorrs, strict=True): + # generate the covariance matrix of this block for the iterative sampler + block_cov = cls2cov(block, block_n, ngrf, block_ncorr) + # generate the conditional normal distribution for iterative sampling + conditional_dist = iternorm(block_ncorr, block_cov, size=block_n) + # store for parallel processing of all blocks + conditional_dists.append(conditional_dist) # sample the fields from the conditional distribution - for j, a, s in conditional_dist: + for results, ncorr in zip(*conditional_dists, ncorrs, strict=True): # standard normal random variates for alm # sample real and imaginary parts, then view as complex number rng.standard_normal(n * (n + 1), np.float64, z.view(np.float64)) + # concatenate individual updates into one update + s = np.concatenate([block_s for _, _, block_s in results]) + # scale by standard deviation of the conditional distribution # variance is distributed over real and imaginary part alm = _multalm(z, s) # add the mean of the conditional distribution + y = np.zeros((n * (n + 1) // 2, ncorr), dtype=np.complex128) + a = np.concatenate([block_a for _, block_a, _ in results]) for i in range(ncorr): alm += _multalm(y[:, i], a[:, i]) + for i in range(ncorr): + # calculate ks + pass + # store the standard normal in y array at the indicated index - if j is not None: - y[:, j] = z + if results[0] is not None: + y[:, results[0]] = z[k1:k2] alm = _glass_to_healpix_alm(alm)