diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 9dd49dc1..7473ab13 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -251,8 +251,9 @@ def _temme_series_kve(v, z): z_sq = z * z logzo2 = jnp.log(z / 2.0) mu = -v * logzo2 - sinc_v = jnp.where(v == 0.0, 1.0, jnp.sin(jnp.pi * v) / (jnp.pi * v)) - sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu) / mu) + sinc_v = jnp.sinc(v) + mu_safe = jnp.where(mu != 0, mu, 1.0) + sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu_safe) / mu_safe) initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v initial_p = 0.5 * jnp.exp(mu) / gamma1pv_inv @@ -711,6 +712,7 @@ def t3(x): # x>8 return factor * (rc * (cx + sx) - y * rs * (sx - cx)) x = jnp.abs(x) + x_ = jnp.where(x != 0, x, 1.0) return jnp.select( - [x == 0, x <= 4, x <= 8, x > 8], [1, t1(x), t2(x), t3(x)], default=x + [x == 0, x <= 4, x <= 8, x > 8], [1, t1(x_), t2(x_), t3(x_)], default=x ).reshape(orig_shape) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 6961ad39..5296e4ed 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -63,58 +63,61 @@ def __init__(self, *args, **kwargs): % kwargs.keys() ) + # MRB: we donot run these code blocks since we do not support + # real-space convolutions and they break tracing # Check whether to perform real space convolution... # Start by checking if all objects have a hard edge. - hard_edge = True - for obj in args: - if not isinstance(obj, GSObject): - raise TypeError( - "Arguments to Convolution must be GSObjects, not %s" % obj - ) - if not obj.has_hard_edges: - hard_edge = False - - if real_space is None: - # The automatic determination is to use real_space if 2 items, both with hard edges. - if len(args) <= 2: - real_space = hard_edge - else: - real_space = False - elif bool(real_space) != real_space: - raise TypeError("real_space must be a boolean") - - # Warn if doing DFT convolution for objects with hard edges - if not real_space and hard_edge: - if len(args) == 2: - galsim_warn( - "Doing convolution of 2 objects, both with hard edges. " - "This might be more accurate with `real_space=True`, " - "but this functionality has not yet been implemented in JAX-Galsim." - ) - else: - galsim_warn( - "Doing convolution where all objects have hard edges. " - "There might be some inaccuracies due to ringing in k-space." - ) - if real_space: - # Can't do real space if nobj > 2 - if len(args) > 2: - galsim_warn( - "Real-space convolution of more than 2 objects is not implemented. " - "Switching to DFT method." - ) - real_space = False - - # Also can't do real space if any object is not analytic, so check for that. - else: - for obj in args: - if not obj.is_analytic_x: - galsim_warn( - "A component to be convolved is not analytic in real space. " - "Cannot use real space convolution. Switching to DFT method." - ) - real_space = False - break + # hard_edge = True + # for obj in args: + # if not isinstance(obj, GSObject): + # raise TypeError( + # "Arguments to Convolution must be GSObjects, not %s" % obj + # ) + # if not obj.has_hard_edges: + # hard_edge = False + + # if real_space is None: + # # The automatic determination is to use real_space if 2 items, both with hard edges. + # if len(args) <= 2: + # real_space = hard_edge + # else: + # real_space = False + # elif bool(real_space) != real_space: + # raise TypeError("real_space must be a boolean") + + # # Warn if doing DFT convolution for objects with hard edges + # if not real_space and hard_edge: + # if len(args) == 2: + # galsim_warn( + # "Doing convolution of 2 objects, both with hard edges. " + # "This might be more accurate with `real_space=True`, " + # "but this functionality has not yet been implemented in JAX-Galsim." + # ) + # else: + # galsim_warn( + # "Doing convolution where all objects have hard edges. " + # "There might be some inaccuracies due to ringing in k-space." + # ) + # if real_space: + # # Can't do real space if nobj > 2 + # if len(args) > 2: + # galsim_warn( + # "Real-space convolution of more than 2 objects is not implemented. " + # "Switching to DFT method." + # ) + # real_space = False + + # # Also can't do real space if any object is not analytic, so check for that. + # else: + # for obj in args: + # if not obj.is_analytic_x: + # galsim_warn( + # "A component to be convolved is not analytic in real space. " + # "Cannot use real space convolution. Switching to DFT method." + # ) + # real_space = False + # break + # MRB: end of commented out code blocks # Save the construction parameters (as they are at this point) as attributes so they # can be inspected later if necessary. diff --git a/jax_galsim/core/interpolate.py b/jax_galsim/core/interpolate.py index 7b545148..e2c5126d 100644 --- a/jax_galsim/core/interpolate.py +++ b/jax_galsim/core/interpolate.py @@ -138,7 +138,7 @@ def akima_interp(x, xp, yp, coeffs, fixed_spacing=False): The values of the Akima cubic spline at the points x. """ xp = jnp.asarray(xp) - # yp = jnp.array(yp) # unused + yp = jnp.asarray(yp) if fixed_spacing: dxp = xp[1] - xp[0] i = jnp.floor((x - xp[0]) / dxp).astype(jnp.int32) @@ -160,6 +160,6 @@ def akima_interp(x, xp, yp, coeffs, fixed_spacing=False): dx3 = dx2 * dx xval = a[i] + b[i] * dx + c[i] * dx2 + d[i] * dx3 - xval = jnp.where(x < xp[0], 0, xval) - xval = jnp.where(x > xp[-1], 0, xval) + xval = jnp.where(x < xp[0], yp[0], xval) + xval = jnp.where(x > xp[-1], yp[-1], xval) return xval diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index fe596398..cd3d3593 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -7,9 +7,9 @@ from jax_galsim.bessel import j0, kv from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral +from jax_galsim.core.interpolate import akima_interp, akima_interp_coeffs from jax_galsim.core.utils import bisect_for_root, ensure_hashable, implements from jax_galsim.gsobject import GSObject -from jax_galsim.position import PositionD from jax_galsim.random import UniformDeviate @@ -103,12 +103,13 @@ def __init__( fwhm=fwhm, ) else: + trunc_ = jnp.where(trunc > 0, trunc, 50.0) super().__init__( beta=beta, scale_radius=( - jax.lax.select( + jnp.where( trunc > 0, - _MoffatCalculateSRFromHLR(half_light_radius, trunc, beta), + _MoffatCalculateSRFromHLR(half_light_radius, trunc_, beta), half_light_radius / jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0), ) @@ -281,7 +282,19 @@ def _prefactor(self): @jax.jit def _maxk_func(self, k): return ( - jnp.abs(self._kValue(PositionD(x=k, y=0)).real / self.flux) + jnp.abs( + self._kValue_func( + self.beta, + jnp.atleast_1d(k), + self._knorm_bis, + self._knorm, + self._prefactor, + self._maxRrD, + self.trunc, + self._r0, + )[0].real + / self.flux + ) - self.gsparams.maxk_threshold ) @@ -336,36 +349,113 @@ def _xValue(self, pos): rsq > self._maxRrD_sq, 0.0, self._norm * jnp.power(1.0 + rsq, -self.beta) ) - def _kValue_untrunc(self, k): + @staticmethod + @jax.jit + def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm, _r0): """Non truncated version of _kValue""" + k_ = jnp.where(k > 0, k * _r0, 1.0) return jnp.where( k > 0, - self._knorm_bis * jnp.power(k, self.beta - 1.0) * _Knu(self.beta - 1.0, k), - self._knorm, + _knorm_bis * jnp.power(k_, beta - 1.0) * _Knu(beta - 1.0, k_), + _knorm, ) - def _kValue_trunc(self, k): + @staticmethod + @jax.jit + def _kValue_trunc_func(beta, k, _knorm, _prefactor, _maxRrD, _r0): """Truncated version of _kValue""" + k_ = k * _r0 + k_ = jnp.where(k_ <= 50.0, k_, 50.0) return jnp.where( - k <= 50.0, - self._knorm * self._prefactor * _hankel(k, self.beta, self._maxRrD), + k_ <= 50.0, + _knorm * _prefactor * _hankel(k_, beta, _maxRrD), 0.0, ) + @staticmethod + @jax.jit + def _kValue_func(beta, k, _knorm_bis, _knorm, _prefactor, _maxRrD, trunc, _r0): + return jax.lax.cond( + trunc > 0, + lambda x: Moffat._kValue_trunc_func( + beta, x, _knorm, _prefactor, _maxRrD, _r0 + ), + lambda x: Moffat._kValue_untrunc_func(beta, x, _knorm_bis, _knorm, _r0), + k, + ) + + @staticmethod + @jax.jit + def _kValue_untrunc_asymp_func(beta, k, _knorm_bis, _r0): + kr0 = k * _r0 + return ( + _knorm_bis + * jnp.power(kr0, beta - 1.0) + * jnp.exp(-kr0) + * jnp.sqrt(jnp.pi / 2 / kr0) + ) + + def _kValue_untrunc_interp_coeffs(self): + # this number of points gets the tests to pass + # I did not investigate further. + n_pts = 2000 + k_min = 0 + # this is a fudge factor to help numerical convergnce in the tests + # it should not be needed in principle since the profile is not + # evaluated above maxk, but it appears to be needed anyway and + # IDK why + k_max = self._maxk * 2 + k = jnp.linspace(k_min, k_max, n_pts) + vals = self._kValue_untrunc_func( + self.beta, + k, + self._knorm_bis, + self._knorm, + self._r0, + ) + + # slope to match the interpolant onto an asymptotic expansion of kv + # that is kv(x) ~ sqrt(pi/2/x) * exp(-x) * (1 + slp/x) + aval = self._kValue_untrunc_asymp_func( + self.beta, k[-1], self._knorm_bis, self._r0 + ) + slp = (vals[-1] / aval - 1) * k[-1] * self._r0 + + return k, vals, akima_interp_coeffs(k, vals), slp + @jax.jit def _kValue(self, kpos): """computation of the Moffat response in k-space with switch of truncated/untracated case kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image) """ - k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq) + k = jnp.sqrt((kpos.x**2 + kpos.y**2)) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) + + # for untruncated profiles, we interpolate and use and asymptotic + # expansion for extrapolation + def _run_untrunc(krun): + k_, vals_, coeffs, slp = self._kValue_untrunc_interp_coeffs() + res = akima_interp(krun, k_, vals_, coeffs, fixed_spacing=True) + krun = jnp.where(krun > 0, krun, k_[1]) + return jnp.where( + krun > k_[-1], + self._kValue_untrunc_asymp_func( + self.beta, krun, self._knorm_bis, self._r0 + ) + * (1.0 + slp / krun / self._r0), + res, + ) + res = jax.lax.cond( self.trunc > 0, - lambda x: self._kValue_trunc(x), - lambda x: self._kValue_untrunc(x), + lambda x: Moffat._kValue_trunc_func( + self.beta, x, self._knorm, self._prefactor, self._maxRrD, self._r0 + ), + lambda x: _run_untrunc(x), k, ) + return res.reshape(out_shape) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/tests/GalSim b/tests/GalSim index 3251a393..04918b11 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 3251a393bf7ea94fe9ccda3508bc7db722eca1cf +Subproject commit 04918b118926eafc01ec9403b8afed29fb918d51 diff --git a/tests/jax/test_derivs_params.py b/tests/jax/test_derivs_params.py new file mode 100644 index 00000000..ffb33fa1 --- /dev/null +++ b/tests/jax/test_derivs_params.py @@ -0,0 +1,95 @@ +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import jax_galsim as jgs + + +@pytest.mark.parametrize( + "params,gsobj,args", + [ + (["scale_radius", "half_light_radius"], jgs.Spergel, [1.0]), + (["scale_radius", "half_light_radius"], jgs.Exponential, []), + (["sigma", "fwhm", "half_light_radius"], jgs.Gaussian, []), + (["scale_radius", "half_light_radius", "fwhm"], jgs.Moffat, [2.0]), + ], +) +def test_deriv_params_gsobject(params, gsobj, args): + val = 2.0 + eps = 1e-5 + + for param in params: + print("\nparam:", param, flush=True) + + def _run(val_): + kwargs = {param: val_} + return jnp.max( + gsobj( + *args, + **kwargs, + gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), + ) + .drawImage(nx=48, ny=48, scale=0.2, method="fft") + .array[24, 24] + ** 2 + ) + + gfunc = jax.jit(jax.grad(_run)) + gval = gfunc(val) + + gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps + + np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) + + +def test_deriv_params_moffat_with_trunc(): + val = 2.0 + eps = 1e-5 + + def _run(val_): + return jnp.max( + jgs.Moffat( + 2.5, + half_light_radius=val_, + trunc=20.0, + gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), + ) + .drawImage(nx=48, ny=48, scale=0.2) + .array[24, 24] + ** 2 + ) + + gfunc = jax.jit(jax.grad(_run)) + with jax.disable_jit(), jax.debug_nans(): + gval = gfunc(val) + + gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps + + np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) + + +def test_deriv_params_moffat_with_respect_to_trunc(): + val = 20.0 + eps = 1e-5 + + def _run(val_): + return jnp.max( + jgs.Moffat( + 2.5, + half_light_radius=2.0, + trunc=val_, + gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), + ) + .drawImage(nx=48, ny=48, scale=0.2) + .array[24, 24] + ** 2 + ) + + gfunc = jax.jit(jax.grad(_run)) + with jax.disable_jit(), jax.debug_nans(): + gval = gfunc(val) + + gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps + + np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 04376cd3..c5b04b27 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -6,8 +6,9 @@ import jax_galsim as galsim -def test_moffat_comp_galsim_maxk(): - psfs = [ +@pytest.mark.parametrize( + "psf", + [ # Make sure to include all the specialized betas we have in C++ layer. # The scale_radius and flux don't matter, but vary themm too. # Note: We also specialize beta=1, but that seems to be impossible to realize, @@ -25,37 +26,51 @@ def test_moffat_comp_galsim_maxk(): galsim.Moffat(beta=1.22, scale_radius=7, flux=23, trunc=30), galsim.Moffat(beta=3.6, scale_radius=9, flux=23, trunc=50), galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000), - ] - threshs = [1.0e-3, 1.0e-4, 0.03] - print("\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk") - for psf in psfs: - for thresh in threshs: - psf = psf.withGSParams(maxk_threshold=thresh) - gpsf = _galsim.Moffat( - beta=psf.beta, - scale_radius=psf.scale_radius, - flux=psf.flux, - trunc=psf.trunc, - ) - gpsf = gpsf.withGSParams(maxk_threshold=thresh) - fk = psf.kValue(psf.maxk, 0).real / psf.flux + ], +) +@pytest.mark.parametrize("thresh", [1.0e-4, 1.0e-3, 0.03]) +def test_moffat_comp_galsim_maxk(psf, thresh): + print( + "\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk", flush=True + ) + psf = psf.withGSParams(maxk_threshold=thresh) + gpsf = _galsim.Moffat( + beta=psf.beta, + scale_radius=psf.scale_radius, + flux=psf.flux, + trunc=psf.trunc, + ) + gpsf = gpsf.withGSParams(maxk_threshold=thresh) + fk = psf.kValue(psf.maxk, 0).real / psf.flux + maxk_test_val_one = jnp.minimum(1.0, psf.maxk) + maxk_test_val_pone = maxk_test_val_one / 10.0 - print( - f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}" - ) - np.testing.assert_allclose( - psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5 - ) - np.testing.assert_allclose( - psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5 - ) - np.testing.assert_allclose( - psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5 - ) - np.testing.assert_allclose( - psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5 - ) - np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) + print( + f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}", + flush=True, + ) + np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) + np.testing.assert_allclose( + psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-5 + ) + np.testing.assert_allclose( + psf.kValue(0.0, maxk_test_val_pone), + gpsf.kValue(0.0, maxk_test_val_pone), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + psf.kValue(-maxk_test_val_one, 0.0), + gpsf.kValue(-maxk_test_val_one, 0.0), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + psf.kValue(maxk_test_val_one, 0.0), + gpsf.kValue(maxk_test_val_one, 0.0), + rtol=1e-5, + atol=1e-5, + ) @pytest.mark.test_in_float32