diff --git a/pymaster/utils.py b/pymaster/utils.py index 4c2a265b..8043e9d6 100644 --- a/pymaster/utils.py +++ b/pymaster/utils.py @@ -376,7 +376,8 @@ def mask_apodization_flat(mask_in, lx, ly, aposize, apotype="C1"): return mask_apo_flat.reshape([ny, nx]) -def synfast_spherical(nside, cls, spin_arr, beam=None, seed=-1, wcs=None): +def synfast_spherical(nside, cls, spin_arr, beam=None, seed=-1, + wcs=None, lmax=None): """ Generates a full-sky Gaussian random field according to a given \ power spectrum. This function should produce outputs similar to \ @@ -422,20 +423,25 @@ def synfast_spherical(nside, cls, spin_arr, beam=None, seed=-1, wcs=None): raise ValueError("Given your WCS, the map wouldn't cover the " "whole sphere exactly") + if lmax is None: + lmax = wt.get_lmax() + spin_arr = np.array(spin_arr).astype(np.int32) nfields = len(spin_arr) if np.any(spin_arr < 0): raise ValueError("Spins must be positive") - nmaps = int(1 * np.sum(spin_arr == 0) + 2 * np.sum(spin_arr != 0)) + nmap_arr = np.array([1+int(s != 0) for s in spin_arr]) + map_first = np.concatenate([[0], np.cumsum(nmap_arr)[:-1]]) + nmaps = np.sum(nmap_arr) ncls = (nmaps * (nmaps + 1)) // 2 if ncls != len(cls): raise ValueError( - "Must provide all Cls necessary to simulate all " - "fields (%d)." % ncls - ) - lmax = len(cls[0]) - 1 + f"Must provide all Cls necessary to simulate all field ({ncls}).") + lmax_cls = len(cls[0]) - 1 + lmax = min(lmax_cls, lmax) + ainfo = AlmInfo(lmax) if beam is None: beam = np.ones([nfields, lmax + 1]) @@ -447,14 +453,16 @@ def synfast_spherical(nside, cls, spin_arr, beam=None, seed=-1, wcs=None): "The beam should have as many multipoles as the power spectrum" ) - data = lib.synfast_new(wt.is_healpix, wt.nside, wt.nx, wt.ny, - wt.d_phi, wt.d_theta, wt.phi0, wt.theta_max, - spin_arr, seed, cls, beam, nmaps * wt.npix) + # Note that, if `new=False` stops being allowed in healpy, we'll need + # to change the Cl ordering. + alms = np.array(hp.synalm(cls, lmax=lmax, mmax=lmax, new=False)) + maps = np.concatenate([alm2map(alms[i0:i0+n], s, wt.minfo, ainfo) + for i0, n, s in zip(map_first, nmap_arr, spin_arr)]) if wt.is_healpix: - maps = data.reshape([nmaps, wt.npix]) + maps = maps.reshape([nmaps, wt.npix]) else: - maps = data.reshape([nmaps, wt.ny, wt.nx]) + maps = maps.reshape([nmaps, wt.ny, wt.nx]) if wt.flip_th: maps = maps[:, ::-1, :] if wt.flip_ph: