diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index dc9880f9..23e5588a 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -8,6 +8,7 @@ from jax_galsim.bessel import kv from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue +from jax_galsim.core.interpolate import akima_interp, akima_interp_coeffs from jax_galsim.core.math import safe_sqrt from jax_galsim.core.utils import ( ensure_hashable, @@ -302,26 +303,72 @@ def _xValue(self, pos): rsq > self._maxRrD_sq, 0.0, self._norm * jnp.power(1.0 + rsq, -self.beta) ) - def _kValue_untrunc(self, k): + @staticmethod + @jax.jit + def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm, _r0): """Non truncated version of _kValue""" - k_msk = jnp.where(k > 0, k, 1.0) + msk = k > 0 + kr0_msk = jnp.where(msk, k, 1.0) * _r0 return jnp.where( - k > 0, - self._knorm_bis - * jnp.power(k_msk, self.beta - 1.0) - * _Knu(self.beta - 1.0, k_msk), + msk, + _knorm_bis * jnp.power(kr0_msk, beta - 1.0) * _Knu(beta - 1.0, kr0_msk), + _knorm, + ) + + @staticmethod + @jax.jit + def _kValue_untrunc_asymp_func(beta, k, _knorm_bis, _r0): + kr0 = k * _r0 + return ( + _knorm_bis + * jnp.power(kr0, beta - 1.0) + * jnp.exp(-kr0) + * jnp.sqrt(jnp.pi / 2 / kr0) + ) + + def _kValue_untrunc_interp_coeffs(self): + # MRB: this number of points gets the tests to pass + # I did not investigate further. + n_pts = 700 + k_min = 0 + k_max = self._maxk + k = jnp.linspace(k_min, k_max, n_pts) + vals = self._kValue_untrunc_func( + self.beta, + k, + self._knorm_bis, self._knorm, + self._r0, ) + # slope to match the interpolant onto an asymptotic expansion of kv + # that is kv(x) ~ sqrt(pi/2/x) * exp(-x) * (1 + slp/x) + aval = self._kValue_untrunc_asymp_func( + self.beta, k[-1], self._knorm_bis, self._r0 + ) + slp = (vals[-1] / aval - 1) * k[-1] * self._r0 + + return k, vals, akima_interp_coeffs(k, vals), slp + @jax.jit def _kValue(self, kpos): - """computation of the Moffat response in k-space with switch of truncated/untracated case + """computation of the Moffat response in k-space with interpolant + expansions kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image) """ - k = safe_sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq) + k = safe_sqrt(kpos.x**2 + kpos.y**2) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) - res = self._kValue_untrunc(k) + + k_, vals_, coeffs, slp = self._kValue_untrunc_interp_coeffs() + res = akima_interp(k, k_, vals_, coeffs, fixed_spacing=True) + k_msk = jnp.where(k > 0, k, k_[1]) + res = jnp.where( + k > k_[-1], + self._kValue_untrunc_asymp_func(self.beta, k_msk, self._knorm_bis, self._r0) + * (1.0 + slp / k_msk / self._r0), + res, + ) + return res.reshape(out_shape) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 4b8549c4..20d026e8 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -6,8 +6,9 @@ import jax_galsim as galsim -def test_moffat_comp_galsim_maxk(): - psfs = [ +@pytest.mark.parametrize( + "psf", + [ # Make sure to include all the specialized betas we have in C++ layer. # The scale_radius and flux don't matter, but vary themm too. # Note: We also specialize beta=1, but that seems to be impossible to realize, @@ -22,37 +23,70 @@ def test_moffat_comp_galsim_maxk(): galsim.Moffat(beta=1.22, scale_radius=23, flux=23), galsim.Moffat(beta=3.6, scale_radius=2, flux=23), galsim.Moffat(beta=12.9, scale_radius=5, flux=23), - ] - threshs = [1.0e-3, 1.0e-4, 0.03] - print("\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk") - for psf in psfs: - for thresh in threshs: - psf = psf.withGSParams(maxk_threshold=thresh) - gpsf = _galsim.Moffat( - beta=psf.beta, - scale_radius=psf.scale_radius, - flux=psf.flux, - trunc=psf.trunc, - ) - gpsf = gpsf.withGSParams(maxk_threshold=thresh) - fk = psf.kValue(psf.maxk, 0).real / psf.flux + ], +) +@pytest.mark.parametrize("thresh", [1.0e-4, 1.0e-3, 0.03]) +def test_moffat_comp_galsim_maxk(psf, thresh): + print( + "\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk", flush=True + ) + psf = psf.withGSParams(maxk_threshold=thresh) + gpsf = _galsim.Moffat( + beta=psf.beta, + scale_radius=psf.scale_radius, + flux=psf.flux, + trunc=psf.trunc, + ) + gpsf = gpsf.withGSParams(maxk_threshold=thresh) + fk = psf.kValue(psf.maxk, 0).real / psf.flux + maxk_test_val_one = jnp.minimum(1.0, psf.maxk) + maxk_test_val_pone = maxk_test_val_one / 10.0 - print( - f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}" - ) - np.testing.assert_allclose( - psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5 - ) - np.testing.assert_allclose( - psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5 - ) - np.testing.assert_allclose( - psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5 - ) - np.testing.assert_allclose( - psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5 - ) - np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) + print( + f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}", + flush=True, + ) + np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=1e-4, atol=0) + np.testing.assert_allclose( + psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-5 + ) + np.testing.assert_allclose( + psf.kValue(0.0, maxk_test_val_pone), + gpsf.kValue(0.0, maxk_test_val_pone), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + psf.kValue(-maxk_test_val_one, 0.0), + gpsf.kValue(-maxk_test_val_one, 0.0), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + psf.kValue(maxk_test_val_one, 0.0), + gpsf.kValue(maxk_test_val_one, 0.0), + rtol=1e-5, + atol=1e-5, + ) + + np.testing.assert_allclose( + psf.kValue(0.0, 0.1), + gpsf.kValue(0.0, 0.1), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + psf.kValue(-1.0, 0.0), + gpsf.kValue(-1.0, 0.0), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + psf.kValue(1.0, 0.0), + gpsf.kValue(1.0, 0.0), + rtol=1e-5, + atol=1e-5, + ) @pytest.mark.test_in_float32