From bb69aa4f7ad130c378b59a1127c3c8d1d63f9598 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Feb 2026 07:11:39 -0600 Subject: [PATCH 01/18] fix: ensure moffat derivs are not nan --- jax_galsim/bessel.py | 5 +- jax_galsim/moffat.py | 90 ++++++++++++++++++++++++---- tests/jax/test_derivs_params.py | 43 +++++++++++++ tests/jax/test_moffat_comp_galsim.py | 8 +-- 4 files changed, 128 insertions(+), 18 deletions(-) create mode 100644 tests/jax/test_derivs_params.py diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 9dd49dc1..6dc8f7c1 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -251,8 +251,9 @@ def _temme_series_kve(v, z): z_sq = z * z logzo2 = jnp.log(z / 2.0) mu = -v * logzo2 - sinc_v = jnp.where(v == 0.0, 1.0, jnp.sin(jnp.pi * v) / (jnp.pi * v)) - sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu) / mu) + sinc_v = jnp.sinc(v) + mu_safe = jnp.where(mu > 0, mu, 1.0) + sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu_safe) / mu_safe) initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v initial_p = 0.5 * jnp.exp(mu) / gamma1pv_inv diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index fe596398..8db93ac8 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -7,10 +7,11 @@ from jax_galsim.bessel import j0, kv from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral +from jax_galsim.core.interpolate import akima_interp, akima_interp_coeffs from jax_galsim.core.utils import bisect_for_root, ensure_hashable, implements from jax_galsim.gsobject import GSObject -from jax_galsim.position import PositionD from jax_galsim.random import UniformDeviate +from jax_galsim.utilities import lazy_property @jax.jit @@ -103,12 +104,13 @@ def __init__( fwhm=fwhm, ) else: + trunc_ = jnp.where(trunc > 0, trunc, 100.0) super().__init__( beta=beta, scale_radius=( jax.lax.select( trunc > 0, - _MoffatCalculateSRFromHLR(half_light_radius, trunc, beta), + _MoffatCalculateSRFromHLR(half_light_radius, trunc_, beta), half_light_radius / jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0), ) @@ -281,7 +283,19 @@ def _prefactor(self): @jax.jit def _maxk_func(self, k): return ( - jnp.abs(self._kValue(PositionD(x=k, y=0)).real / self.flux) + jnp.abs( + self._kValue_func( + self.beta, + jnp.atleast_1d(k), + self._knorm_bis, + self._knorm, + self._prefactor, + self._maxRrD, + self.trunc, + self._r0, + )[0].real + / self.flux + ) - self.gsparams.maxk_threshold ) @@ -336,19 +350,74 @@ def _xValue(self, pos): rsq > self._maxRrD_sq, 0.0, self._norm * jnp.power(1.0 + rsq, -self.beta) ) + @staticmethod + @jax.jit + def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm): + """Non truncated version of _kValue""" + k_ = jnp.where(k > 0, k, 1.0) + return jnp.where( + k > 0, + _knorm_bis * jnp.power(k_, beta - 1.0) * _Knu(beta - 1.0, k_), + _knorm, + ) + + @staticmethod + @jax.jit + def _kValue_trunc_func(beta, k, _knorm, _prefactor, _maxRrD): + """Truncated version of _kValue""" + k_ = jnp.where(k <= 50.0, k, 50.0) + return jnp.where( + k <= 50.0, + _knorm * _prefactor * _hankel(k_, beta, _maxRrD), + 0.0, + ) + + @staticmethod + @jax.jit + def _kValue_func(beta, k, _knorm_bis, _knorm, _prefactor, _maxRrD, trunc, _r0): + return jax.lax.cond( + trunc > 0, + lambda x: Moffat._kValue_trunc_func(beta, x, _knorm, _prefactor, _maxRrD), + lambda x: Moffat._kValue_untrunc_func(beta, x, _knorm_bis, _knorm), + k * _r0, + ) + + @lazy_property + def _kValue_interp_coeffs(self): + n_pts = 1000 + k_min = 0.0 + k_max = self._maxk + k = jnp.linspace(k_min, k_max, n_pts) + vals = self._kValue_func( + self.beta, + k, + self._knorm_bis, + self._knorm, + self._prefactor, + self._maxRrD, + self.trunc, + self._r0, + ) + + return k, vals, akima_interp_coeffs(k, vals) + def _kValue_untrunc(self, k): """Non truncated version of _kValue""" + k_ = jnp.where(k > 0, k, 1.0) return jnp.where( k > 0, - self._knorm_bis * jnp.power(k, self.beta - 1.0) * _Knu(self.beta - 1.0, k), + self._knorm_bis + * jnp.power(k_, self.beta - 1.0) + * _Knu(self.beta - 1.0, k_), self._knorm, ) def _kValue_trunc(self, k): """Truncated version of _kValue""" + k_ = jnp.where(k <= 50.0, k, 50.0) return jnp.where( k <= 50.0, - self._knorm * self._prefactor * _hankel(k, self.beta, self._maxRrD), + self._knorm * self._prefactor * _hankel(k_, self.beta, self._maxRrD), 0.0, ) @@ -357,15 +426,12 @@ def _kValue(self, kpos): """computation of the Moffat response in k-space with switch of truncated/untracated case kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image) """ - k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq) + k = jnp.sqrt((kpos.x**2 + kpos.y**2)) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) - res = jax.lax.cond( - self.trunc > 0, - lambda x: self._kValue_trunc(x), - lambda x: self._kValue_untrunc(x), - k, - ) + k_, vals_, coeffs = self._kValue_interp_coeffs + res = akima_interp(k, k_, vals_, coeffs, fixed_spacing=True) + return res.reshape(out_shape) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/tests/jax/test_derivs_params.py b/tests/jax/test_derivs_params.py new file mode 100644 index 00000000..3dffd5b0 --- /dev/null +++ b/tests/jax/test_derivs_params.py @@ -0,0 +1,43 @@ +import jax +import jax.numpy as jnp +import numpy as np + +import jax_galsim as jgs + +import pytest + + +@pytest.mark.parametrize( + "params,gsobj,args", + [ + (["scale_radius", "half_light_radius"], jgs.Spergel, [1.0]), + (["scale_radius", "half_light_radius"], jgs.Exponential, []), + (["sigma", "fwhm", "half_light_radius"], jgs.Gaussian, []), + (["scale_radius", "half_light_radius", "fwhm"], jgs.Moffat, [2.0]), + ] +) +def test_deriv_params_gsobject(params, gsobj, args): + val = 2.0 + eps = 1e-5 + + for param in params: + print("\nparam:", param, flush=True) + + def _run(val_): + kwargs = {param: val_} + return jnp.max( + gsobj( + *args, + **kwargs, + gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), + ).drawImage(nx=48, ny=48, scale=0.2).array[24, 24]**2 + ) + + gfunc = jax.jit(jax.grad(_run)) + gval = gfunc(val) + + gfdiff = ( + _run(val + eps) - _run(val - eps) + ) / 2.0 / eps + + np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 04376cd3..1f2be63e 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -44,16 +44,16 @@ def test_moffat_comp_galsim_maxk(): f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}" ) np.testing.assert_allclose( - psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5 + psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-8 ) np.testing.assert_allclose( - psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5 + psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5, atol=1e-8 ) np.testing.assert_allclose( - psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5 + psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5, atol=1e-8 ) np.testing.assert_allclose( - psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5 + psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5, atol=1e-8 ) np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) From fbe4dac7542a943babd5de876a5537b27217a35a Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 10 Feb 2026 08:18:44 -0600 Subject: [PATCH 02/18] Update bessel.py --- jax_galsim/bessel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 6dc8f7c1..d947a8c7 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -252,7 +252,7 @@ def _temme_series_kve(v, z): logzo2 = jnp.log(z / 2.0) mu = -v * logzo2 sinc_v = jnp.sinc(v) - mu_safe = jnp.where(mu > 0, mu, 1.0) + mu_safe = jnp.where(mu != 0, mu, 1.0) sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu_safe) / mu_safe) initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v From 7ef38d4c4c6435a4163a83f1a8b707dbb2ad0d59 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 14:19:24 +0000 Subject: [PATCH 03/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_derivs_params.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/jax/test_derivs_params.py b/tests/jax/test_derivs_params.py index 3dffd5b0..83438bc7 100644 --- a/tests/jax/test_derivs_params.py +++ b/tests/jax/test_derivs_params.py @@ -1,11 +1,10 @@ import jax import jax.numpy as jnp import numpy as np +import pytest import jax_galsim as jgs -import pytest - @pytest.mark.parametrize( "params,gsobj,args", @@ -14,7 +13,7 @@ (["scale_radius", "half_light_radius"], jgs.Exponential, []), (["sigma", "fwhm", "half_light_radius"], jgs.Gaussian, []), (["scale_radius", "half_light_radius", "fwhm"], jgs.Moffat, [2.0]), - ] + ], ) def test_deriv_params_gsobject(params, gsobj, args): val = 2.0 @@ -30,14 +29,15 @@ def _run(val_): *args, **kwargs, gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), - ).drawImage(nx=48, ny=48, scale=0.2).array[24, 24]**2 + ) + .drawImage(nx=48, ny=48, scale=0.2) + .array[24, 24] + ** 2 ) gfunc = jax.jit(jax.grad(_run)) gval = gfunc(val) - gfdiff = ( - _run(val + eps) - _run(val - eps) - ) / 2.0 / eps + gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) From ea7bde85b307f42dfa1f014e8ae758a742846fd9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Feb 2026 11:03:12 -0600 Subject: [PATCH 04/18] fix: do not use lazy property --- jax_galsim/core/interpolate.py | 6 +-- jax_galsim/moffat.py | 10 ++-- tests/jax/test_moffat_comp_galsim.py | 79 +++++++++++++++++----------- 3 files changed, 54 insertions(+), 41 deletions(-) diff --git a/jax_galsim/core/interpolate.py b/jax_galsim/core/interpolate.py index 7b545148..e2c5126d 100644 --- a/jax_galsim/core/interpolate.py +++ b/jax_galsim/core/interpolate.py @@ -138,7 +138,7 @@ def akima_interp(x, xp, yp, coeffs, fixed_spacing=False): The values of the Akima cubic spline at the points x. """ xp = jnp.asarray(xp) - # yp = jnp.array(yp) # unused + yp = jnp.asarray(yp) if fixed_spacing: dxp = xp[1] - xp[0] i = jnp.floor((x - xp[0]) / dxp).astype(jnp.int32) @@ -160,6 +160,6 @@ def akima_interp(x, xp, yp, coeffs, fixed_spacing=False): dx3 = dx2 * dx xval = a[i] + b[i] * dx + c[i] * dx2 + d[i] * dx3 - xval = jnp.where(x < xp[0], 0, xval) - xval = jnp.where(x > xp[-1], 0, xval) + xval = jnp.where(x < xp[0], yp[0], xval) + xval = jnp.where(x > xp[-1], yp[-1], xval) return xval diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 8db93ac8..c6943dcc 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -11,7 +11,6 @@ from jax_galsim.core.utils import bisect_for_root, ensure_hashable, implements from jax_galsim.gsobject import GSObject from jax_galsim.random import UniformDeviate -from jax_galsim.utilities import lazy_property @jax.jit @@ -382,10 +381,10 @@ def _kValue_func(beta, k, _knorm_bis, _knorm, _prefactor, _maxRrD, trunc, _r0): k * _r0, ) - @lazy_property + @jax.jit def _kValue_interp_coeffs(self): - n_pts = 1000 - k_min = 0.0 + n_pts = 5000 + k_min = 0 k_max = self._maxk k = jnp.linspace(k_min, k_max, n_pts) vals = self._kValue_func( @@ -398,7 +397,6 @@ def _kValue_interp_coeffs(self): self.trunc, self._r0, ) - return k, vals, akima_interp_coeffs(k, vals) def _kValue_untrunc(self, k): @@ -429,7 +427,7 @@ def _kValue(self, kpos): k = jnp.sqrt((kpos.x**2 + kpos.y**2)) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) - k_, vals_, coeffs = self._kValue_interp_coeffs + k_, vals_, coeffs = self._kValue_interp_coeffs() res = akima_interp(k, k_, vals_, coeffs, fixed_spacing=True) return res.reshape(out_shape) diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 1f2be63e..cb75dd06 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -6,8 +6,9 @@ import jax_galsim as galsim -def test_moffat_comp_galsim_maxk(): - psfs = [ +@pytest.mark.parametrize( + "psf", + [ # Make sure to include all the specialized betas we have in C++ layer. # The scale_radius and flux don't matter, but vary themm too. # Note: We also specialize beta=1, but that seems to be impossible to realize, @@ -25,37 +26,51 @@ def test_moffat_comp_galsim_maxk(): galsim.Moffat(beta=1.22, scale_radius=7, flux=23, trunc=30), galsim.Moffat(beta=3.6, scale_radius=9, flux=23, trunc=50), galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000), - ] - threshs = [1.0e-3, 1.0e-4, 0.03] - print("\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk") - for psf in psfs: - for thresh in threshs: - psf = psf.withGSParams(maxk_threshold=thresh) - gpsf = _galsim.Moffat( - beta=psf.beta, - scale_radius=psf.scale_radius, - flux=psf.flux, - trunc=psf.trunc, - ) - gpsf = gpsf.withGSParams(maxk_threshold=thresh) - fk = psf.kValue(psf.maxk, 0).real / psf.flux + ], +) +@pytest.mark.parametrize("thresh", [1.0e-4, 1.0e-3, 0.03]) +def test_moffat_comp_galsim_maxk(psf, thresh): + print( + "\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk", flush=True + ) + psf = psf.withGSParams(maxk_threshold=thresh) + gpsf = _galsim.Moffat( + beta=psf.beta, + scale_radius=psf.scale_radius, + flux=psf.flux, + trunc=psf.trunc, + ) + gpsf = gpsf.withGSParams(maxk_threshold=thresh) + fk = psf.kValue(psf.maxk, 0).real / psf.flux + maxk_test_val_one = jnp.minimum(1.0, psf.maxk) + maxk_test_val_pone = maxk_test_val_one / 10.0 - print( - f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}" - ) - np.testing.assert_allclose( - psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-8 - ) - np.testing.assert_allclose( - psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5, atol=1e-8 - ) - np.testing.assert_allclose( - psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5, atol=1e-8 - ) - np.testing.assert_allclose( - psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5, atol=1e-8 - ) - np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) + print( + f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}", + flush=True, + ) + np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) + np.testing.assert_allclose( + psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-8 + ) + np.testing.assert_allclose( + psf.kValue(0.0, maxk_test_val_pone), + gpsf.kValue(0.0, maxk_test_val_pone), + rtol=1e-5, + atol=1e-8, + ) + np.testing.assert_allclose( + psf.kValue(-maxk_test_val_one, 0.0), + gpsf.kValue(-maxk_test_val_one, 0.0), + rtol=1e-5, + atol=1e-8, + ) + np.testing.assert_allclose( + psf.kValue(maxk_test_val_one, 0.0), + gpsf.kValue(maxk_test_val_one, 0.0), + rtol=1e-5, + atol=1e-8, + ) @pytest.mark.test_in_float32 From 0a02281c3ec3d02f30a0685814715418bb50cbbc Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Feb 2026 12:37:21 -0600 Subject: [PATCH 05/18] fix: ensure tests pass by having enough k range and tons of points --- jax_galsim/moffat.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index c6943dcc..2d7cea38 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -383,9 +383,15 @@ def _kValue_func(beta, k, _knorm_bis, _knorm, _prefactor, _maxRrD, trunc, _r0): @jax.jit def _kValue_interp_coeffs(self): + # this number of points gets the tests to pass + # I did not investigate further. n_pts = 5000 k_min = 0 - k_max = self._maxk + # this is a fudge factor to help numerical convergnce in the tests + # it should not be needed in principle since the profile is not + # evaluated above maxk, but it appears to be needed anyway and + # IDK why + k_max = jnp.minimum(self._maxk * 2, 50.0) k = jnp.linspace(k_min, k_max, n_pts) vals = self._kValue_func( self.beta, From 2d7b9bfe09ef75755e242ea9d94b1bb63af9b974 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 10 Feb 2026 12:38:34 -0600 Subject: [PATCH 06/18] fix: match the maxk for truncation in other parts of the code --- jax_galsim/moffat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 2d7cea38..daca4461 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -103,7 +103,7 @@ def __init__( fwhm=fwhm, ) else: - trunc_ = jnp.where(trunc > 0, trunc, 100.0) + trunc_ = jnp.where(trunc > 0, trunc, 50.0) super().__init__( beta=beta, scale_radius=( From e1be2763277e6487f16ec78d7d31b80666216c3e Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Feb 2026 14:00:50 -0600 Subject: [PATCH 07/18] fix: do not cap this --- jax_galsim/moffat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 2d7cea38..31527037 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -391,7 +391,7 @@ def _kValue_interp_coeffs(self): # it should not be needed in principle since the profile is not # evaluated above maxk, but it appears to be needed anyway and # IDK why - k_max = jnp.minimum(self._maxk * 2, 50.0) + k_max = self._maxk * 2 k = jnp.linspace(k_min, k_max, n_pts) vals = self._kValue_func( self.beta, From a5e799ebea85eabb517697177de70502c46cac29 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Feb 2026 17:17:50 -0600 Subject: [PATCH 08/18] fix: get right fudge factor --- jax_galsim/moffat.py | 2 +- tests/GalSim | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 73edc140..f3792184 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -391,7 +391,7 @@ def _kValue_interp_coeffs(self): # it should not be needed in principle since the profile is not # evaluated above maxk, but it appears to be needed anyway and # IDK why - k_max = self._maxk * 2 + k_max = self._maxk * 5 k = jnp.linspace(k_min, k_max, n_pts) vals = self._kValue_func( self.beta, diff --git a/tests/GalSim b/tests/GalSim index 3251a393..04918b11 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 3251a393bf7ea94fe9ccda3508bc7db722eca1cf +Subproject commit 04918b118926eafc01ec9403b8afed29fb918d51 From 33f12a61ddad10da85c2e5075b6e7142de3be95a Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 10 Feb 2026 17:20:11 -0600 Subject: [PATCH 09/18] fix: adjust tests again --- tests/jax/test_moffat_comp_galsim.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index cb75dd06..c5b04b27 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -51,25 +51,25 @@ def test_moffat_comp_galsim_maxk(psf, thresh): ) np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0) np.testing.assert_allclose( - psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-8 + psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-5 ) np.testing.assert_allclose( psf.kValue(0.0, maxk_test_val_pone), gpsf.kValue(0.0, maxk_test_val_pone), rtol=1e-5, - atol=1e-8, + atol=1e-5, ) np.testing.assert_allclose( psf.kValue(-maxk_test_val_one, 0.0), gpsf.kValue(-maxk_test_val_one, 0.0), rtol=1e-5, - atol=1e-8, + atol=1e-5, ) np.testing.assert_allclose( psf.kValue(maxk_test_val_one, 0.0), gpsf.kValue(maxk_test_val_one, 0.0), rtol=1e-5, - atol=1e-8, + atol=1e-5, ) From 68d43342d4115131800ca547d94e96e6ec2953a5 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Feb 2026 07:58:32 -0600 Subject: [PATCH 10/18] feat: use aymptotic expansion with untruncated profiles --- jax_galsim/moffat.py | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index f3792184..9f433ea4 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -381,7 +381,17 @@ def _kValue_func(beta, k, _knorm_bis, _knorm, _prefactor, _maxRrD, trunc, _r0): k * _r0, ) + @staticmethod @jax.jit + def _kValue_func_untrunc_asymp(beta, k, _knorm_bis, _r0): + kr0 = k * _r0 + return ( + _knorm_bis + * jnp.power(kr0, beta - 1.0) + * jnp.exp(-kr0) + * jnp.sqrt(jnp.pi / 2 / kr0) + ) + def _kValue_interp_coeffs(self): # this number of points gets the tests to pass # I did not investigate further. @@ -391,7 +401,7 @@ def _kValue_interp_coeffs(self): # it should not be needed in principle since the profile is not # evaluated above maxk, but it appears to be needed anyway and # IDK why - k_max = self._maxk * 5 + k_max = self._maxk * 2 k = jnp.linspace(k_min, k_max, n_pts) vals = self._kValue_func( self.beta, @@ -403,7 +413,15 @@ def _kValue_interp_coeffs(self): self.trunc, self._r0, ) - return k, vals, akima_interp_coeffs(k, vals) + + # slope to match the interpolant onto an asymptotic expansion of kv + # that is kv(x) ~ sqrt(pi/2/x) * exp(-x) * (1 + slp/x) + aval = self._kValue_func_untrunc_asymp( + self.beta, k[-1], self._knorm_bis, self._r0 + ) + slp = (vals[-1] / aval - 1) * k[-1] * self._r0 + + return k, vals, akima_interp_coeffs(k, vals), slp def _kValue_untrunc(self, k): """Non truncated version of _kValue""" @@ -433,9 +451,22 @@ def _kValue(self, kpos): k = jnp.sqrt((kpos.x**2 + kpos.y**2)) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) - k_, vals_, coeffs = self._kValue_interp_coeffs() + k_, vals_, coeffs, _ = self._kValue_interp_coeffs() res = akima_interp(k, k_, vals_, coeffs, fixed_spacing=True) + res = jax.lax.cond( + self.trunc > 0, + lambda x: res, + lambda x: jnp.where( + x > k_[-1], + self._kValue_func_untrunc_asymp( + self.beta, x, self._knorm_bis, self._r0 + ), + res, + ), + k, + ) + return res.reshape(out_shape) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): From 8db2728631ea72c1981333c5b47895eada60217e Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Feb 2026 08:24:02 -0600 Subject: [PATCH 11/18] fix: use slope properly --- jax_galsim/moffat.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 9f433ea4..d49788c2 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -351,9 +351,9 @@ def _xValue(self, pos): @staticmethod @jax.jit - def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm): + def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm, _r0): """Non truncated version of _kValue""" - k_ = jnp.where(k > 0, k, 1.0) + k_ = jnp.where(k > 0, k, 1.0) * _r0 return jnp.where( k > 0, _knorm_bis * jnp.power(k_, beta - 1.0) * _Knu(beta - 1.0, k_), @@ -362,9 +362,9 @@ def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm): @staticmethod @jax.jit - def _kValue_trunc_func(beta, k, _knorm, _prefactor, _maxRrD): + def _kValue_trunc_func(beta, k, _knorm, _prefactor, _maxRrD, _r0): """Truncated version of _kValue""" - k_ = jnp.where(k <= 50.0, k, 50.0) + k_ = jnp.where(k <= 50.0, k, 50.0) * _r0 return jnp.where( k <= 50.0, _knorm * _prefactor * _hankel(k_, beta, _maxRrD), @@ -376,14 +376,16 @@ def _kValue_trunc_func(beta, k, _knorm, _prefactor, _maxRrD): def _kValue_func(beta, k, _knorm_bis, _knorm, _prefactor, _maxRrD, trunc, _r0): return jax.lax.cond( trunc > 0, - lambda x: Moffat._kValue_trunc_func(beta, x, _knorm, _prefactor, _maxRrD), - lambda x: Moffat._kValue_untrunc_func(beta, x, _knorm_bis, _knorm), - k * _r0, + lambda x: Moffat._kValue_trunc_func( + beta, x, _knorm, _prefactor, _maxRrD, _r0 + ), + lambda x: Moffat._kValue_untrunc_func(beta, x, _knorm_bis, _knorm, _r0), + k, ) @staticmethod @jax.jit - def _kValue_func_untrunc_asymp(beta, k, _knorm_bis, _r0): + def _kValue_untrunc_asymp_func(beta, k, _knorm_bis, _r0): kr0 = k * _r0 return ( _knorm_bis @@ -416,7 +418,7 @@ def _kValue_interp_coeffs(self): # slope to match the interpolant onto an asymptotic expansion of kv # that is kv(x) ~ sqrt(pi/2/x) * exp(-x) * (1 + slp/x) - aval = self._kValue_func_untrunc_asymp( + aval = self._kValue_untrunc_asymp_func( self.beta, k[-1], self._knorm_bis, self._r0 ) slp = (vals[-1] / aval - 1) * k[-1] * self._r0 @@ -425,7 +427,7 @@ def _kValue_interp_coeffs(self): def _kValue_untrunc(self, k): """Non truncated version of _kValue""" - k_ = jnp.where(k > 0, k, 1.0) + k_ = jnp.where(k > 0, k, 1.0) * self._r0 return jnp.where( k > 0, self._knorm_bis @@ -436,7 +438,7 @@ def _kValue_untrunc(self, k): def _kValue_trunc(self, k): """Truncated version of _kValue""" - k_ = jnp.where(k <= 50.0, k, 50.0) + k_ = jnp.where(k <= 50.0, k, 50.0) * self._r0 return jnp.where( k <= 50.0, self._knorm * self._prefactor * _hankel(k_, self.beta, self._maxRrD), @@ -451,7 +453,7 @@ def _kValue(self, kpos): k = jnp.sqrt((kpos.x**2 + kpos.y**2)) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) - k_, vals_, coeffs, _ = self._kValue_interp_coeffs() + k_, vals_, coeffs, slp = self._kValue_interp_coeffs() res = akima_interp(k, k_, vals_, coeffs, fixed_spacing=True) res = jax.lax.cond( @@ -459,9 +461,8 @@ def _kValue(self, kpos): lambda x: res, lambda x: jnp.where( x > k_[-1], - self._kValue_func_untrunc_asymp( - self.beta, x, self._knorm_bis, self._r0 - ), + self._kValue_untrunc_asymp_func(self.beta, x, self._knorm_bis, self._r0) + * (1.0 + slp / x / self._r0), res, ), k, From be849d3c3fc8defa8b7a2bf0d6892b244554fad2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Feb 2026 08:30:04 -0600 Subject: [PATCH 12/18] fix: nan derivs needs mask at k=0 --- jax_galsim/moffat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index d49788c2..73ca7a19 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -465,7 +465,7 @@ def _kValue(self, kpos): * (1.0 + slp / x / self._r0), res, ), - k, + jnp.where(k > 0, k, k_[1]), ) return res.reshape(out_shape) From 51421d7f7cbfbd7206389f10c8df762c307b5109 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Feb 2026 09:53:20 -0600 Subject: [PATCH 13/18] fix: wrong way to truncate --- jax_galsim/moffat.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 73ca7a19..4a3b2c05 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -353,9 +353,10 @@ def _xValue(self, pos): @jax.jit def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm, _r0): """Non truncated version of _kValue""" - k_ = jnp.where(k > 0, k, 1.0) * _r0 + k_ = k * _r0 + k_ = jnp.where(k_ > 0, k_, 1.0) return jnp.where( - k > 0, + k_ > 0, _knorm_bis * jnp.power(k_, beta - 1.0) * _Knu(beta - 1.0, k_), _knorm, ) @@ -364,9 +365,10 @@ def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm, _r0): @jax.jit def _kValue_trunc_func(beta, k, _knorm, _prefactor, _maxRrD, _r0): """Truncated version of _kValue""" - k_ = jnp.where(k <= 50.0, k, 50.0) * _r0 + k_ = k * _r0 + k_ = jnp.where(k_ <= 50.0, k_, 50.0) return jnp.where( - k <= 50.0, + k_ <= 50.0, _knorm * _prefactor * _hankel(k_, beta, _maxRrD), 0.0, ) From 3b4cf52347b77b4cf4737f73fb3059f17d57e1ca Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Feb 2026 10:48:50 -0600 Subject: [PATCH 14/18] fix: wrong value at k = 0 --- jax_galsim/moffat.py | 45 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 4a3b2c05..069323b6 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -353,10 +353,9 @@ def _xValue(self, pos): @jax.jit def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm, _r0): """Non truncated version of _kValue""" - k_ = k * _r0 - k_ = jnp.where(k_ > 0, k_, 1.0) + k_ = jnp.where(k > 0, k * _r0, 1.0) return jnp.where( - k_ > 0, + k > 0, _knorm_bis * jnp.power(k_, beta - 1.0) * _Knu(beta - 1.0, k_), _knorm, ) @@ -427,25 +426,25 @@ def _kValue_interp_coeffs(self): return k, vals, akima_interp_coeffs(k, vals), slp - def _kValue_untrunc(self, k): - """Non truncated version of _kValue""" - k_ = jnp.where(k > 0, k, 1.0) * self._r0 - return jnp.where( - k > 0, - self._knorm_bis - * jnp.power(k_, self.beta - 1.0) - * _Knu(self.beta - 1.0, k_), - self._knorm, - ) - - def _kValue_trunc(self, k): - """Truncated version of _kValue""" - k_ = jnp.where(k <= 50.0, k, 50.0) * self._r0 - return jnp.where( - k <= 50.0, - self._knorm * self._prefactor * _hankel(k_, self.beta, self._maxRrD), - 0.0, - ) + # def _kValue_untrunc(self, k): + # """Non truncated version of _kValue""" + # k_ = jnp.where(k > 0, k, 1.0) * self._r0 + # return jnp.where( + # k > 0, + # self._knorm_bis + # * jnp.power(k_, self.beta - 1.0) + # * _Knu(self.beta - 1.0, k_), + # self._knorm, + # ) + + # def _kValue_trunc(self, k): + # """Truncated version of _kValue""" + # k_ = jnp.where(k <= 50.0, k, 50.0) * self._r0 + # return jnp.where( + # k <= 50.0, + # self._knorm * self._prefactor * _hankel(k_, self.beta, self._maxRrD), + # 0.0, + # ) @jax.jit def _kValue(self, kpos): @@ -462,7 +461,7 @@ def _kValue(self, kpos): self.trunc > 0, lambda x: res, lambda x: jnp.where( - x > k_[-1], + k > k_[-1], self._kValue_untrunc_asymp_func(self.beta, x, self._knorm_bis, self._r0) * (1.0 + slp / x / self._r0), res, From 55b08b1ae325d83cb8b5fc85e2ea3c83e761e8f6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Feb 2026 13:37:41 -0600 Subject: [PATCH 15/18] fix: only use interp for untraced profiles --- jax_galsim/moffat.py | 62 ++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 069323b6..892ec30d 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -107,11 +107,12 @@ def __init__( super().__init__( beta=beta, scale_radius=( - jax.lax.select( + jax.lax.cond( trunc > 0, - _MoffatCalculateSRFromHLR(half_light_radius, trunc_, beta), - half_light_radius + lambda x: _MoffatCalculateSRFromHLR(x, trunc_, beta), + lambda x: x / jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0), + half_light_radius, ) ), trunc=trunc, @@ -395,7 +396,7 @@ def _kValue_untrunc_asymp_func(beta, k, _knorm_bis, _r0): * jnp.sqrt(jnp.pi / 2 / kr0) ) - def _kValue_interp_coeffs(self): + def _kValue_untrunc_interp_coeffs(self): # this number of points gets the tests to pass # I did not investigate further. n_pts = 5000 @@ -406,14 +407,11 @@ def _kValue_interp_coeffs(self): # IDK why k_max = self._maxk * 2 k = jnp.linspace(k_min, k_max, n_pts) - vals = self._kValue_func( + vals = self._kValue_untrunc_func( self.beta, k, self._knorm_bis, self._knorm, - self._prefactor, - self._maxRrD, - self.trunc, self._r0, ) @@ -426,26 +424,6 @@ def _kValue_interp_coeffs(self): return k, vals, akima_interp_coeffs(k, vals), slp - # def _kValue_untrunc(self, k): - # """Non truncated version of _kValue""" - # k_ = jnp.where(k > 0, k, 1.0) * self._r0 - # return jnp.where( - # k > 0, - # self._knorm_bis - # * jnp.power(k_, self.beta - 1.0) - # * _Knu(self.beta - 1.0, k_), - # self._knorm, - # ) - - # def _kValue_trunc(self, k): - # """Truncated version of _kValue""" - # k_ = jnp.where(k <= 50.0, k, 50.0) * self._r0 - # return jnp.where( - # k <= 50.0, - # self._knorm * self._prefactor * _hankel(k_, self.beta, self._maxRrD), - # 0.0, - # ) - @jax.jit def _kValue(self, kpos): """computation of the Moffat response in k-space with switch of truncated/untracated case @@ -454,19 +432,29 @@ def _kValue(self, kpos): k = jnp.sqrt((kpos.x**2 + kpos.y**2)) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) - k_, vals_, coeffs, slp = self._kValue_interp_coeffs() - res = akima_interp(k, k_, vals_, coeffs, fixed_spacing=True) + + # for untruncated profiles, we interpolate and use and asymptotic + # expansion for extrapolation + def _run_untrunc(krun): + k_, vals_, coeffs, slp = self._kValue_untrunc_interp_coeffs() + res = akima_interp(krun, k_, vals_, coeffs, fixed_spacing=True) + krun = jnp.where(krun > 0, krun, k_[1]) + return jnp.where( + krun > k_[-1], + self._kValue_untrunc_asymp_func( + self.beta, krun, self._knorm_bis, self._r0 + ) + * (1.0 + slp / krun / self._r0), + res, + ) res = jax.lax.cond( self.trunc > 0, - lambda x: res, - lambda x: jnp.where( - k > k_[-1], - self._kValue_untrunc_asymp_func(self.beta, x, self._knorm_bis, self._r0) - * (1.0 + slp / x / self._r0), - res, + lambda x: Moffat._kValue_trunc_func( + self.beta, x, self._knorm, self._prefactor, self._maxRrD, self._r0 ), - jnp.where(k > 0, k, k_[1]), + lambda x: _run_untrunc(x), + k, ) return res.reshape(out_shape) From 0fc251d419f253814bfb980159fe0521c0412fdf Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Feb 2026 13:39:44 -0600 Subject: [PATCH 16/18] perf: use fewer points --- jax_galsim/moffat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 892ec30d..8384dea6 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -399,7 +399,7 @@ def _kValue_untrunc_asymp_func(beta, k, _knorm_bis, _r0): def _kValue_untrunc_interp_coeffs(self): # this number of points gets the tests to pass # I did not investigate further. - n_pts = 5000 + n_pts = 2000 k_min = 0 # this is a fudge factor to help numerical convergnce in the tests # it should not be needed in principle since the profile is not From 99df4fc60f8bc3c3778e5427ed77733d4f2068eb Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Feb 2026 15:15:04 -0600 Subject: [PATCH 17/18] fix: enable derivs for moffats with truncation --- jax_galsim/bessel.py | 3 +- jax_galsim/convolve.py | 103 ++++++++++++++++---------------- tests/jax/test_derivs_params.py | 28 ++++++++- 3 files changed, 82 insertions(+), 52 deletions(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index d947a8c7..7473ab13 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -712,6 +712,7 @@ def t3(x): # x>8 return factor * (rc * (cx + sx) - y * rs * (sx - cx)) x = jnp.abs(x) + x_ = jnp.where(x != 0, x, 1.0) return jnp.select( - [x == 0, x <= 4, x <= 8, x > 8], [1, t1(x), t2(x), t3(x)], default=x + [x == 0, x <= 4, x <= 8, x > 8], [1, t1(x_), t2(x_), t3(x_)], default=x ).reshape(orig_shape) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 6961ad39..5296e4ed 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -63,58 +63,61 @@ def __init__(self, *args, **kwargs): % kwargs.keys() ) + # MRB: we donot run these code blocks since we do not support + # real-space convolutions and they break tracing # Check whether to perform real space convolution... # Start by checking if all objects have a hard edge. - hard_edge = True - for obj in args: - if not isinstance(obj, GSObject): - raise TypeError( - "Arguments to Convolution must be GSObjects, not %s" % obj - ) - if not obj.has_hard_edges: - hard_edge = False - - if real_space is None: - # The automatic determination is to use real_space if 2 items, both with hard edges. - if len(args) <= 2: - real_space = hard_edge - else: - real_space = False - elif bool(real_space) != real_space: - raise TypeError("real_space must be a boolean") - - # Warn if doing DFT convolution for objects with hard edges - if not real_space and hard_edge: - if len(args) == 2: - galsim_warn( - "Doing convolution of 2 objects, both with hard edges. " - "This might be more accurate with `real_space=True`, " - "but this functionality has not yet been implemented in JAX-Galsim." - ) - else: - galsim_warn( - "Doing convolution where all objects have hard edges. " - "There might be some inaccuracies due to ringing in k-space." - ) - if real_space: - # Can't do real space if nobj > 2 - if len(args) > 2: - galsim_warn( - "Real-space convolution of more than 2 objects is not implemented. " - "Switching to DFT method." - ) - real_space = False - - # Also can't do real space if any object is not analytic, so check for that. - else: - for obj in args: - if not obj.is_analytic_x: - galsim_warn( - "A component to be convolved is not analytic in real space. " - "Cannot use real space convolution. Switching to DFT method." - ) - real_space = False - break + # hard_edge = True + # for obj in args: + # if not isinstance(obj, GSObject): + # raise TypeError( + # "Arguments to Convolution must be GSObjects, not %s" % obj + # ) + # if not obj.has_hard_edges: + # hard_edge = False + + # if real_space is None: + # # The automatic determination is to use real_space if 2 items, both with hard edges. + # if len(args) <= 2: + # real_space = hard_edge + # else: + # real_space = False + # elif bool(real_space) != real_space: + # raise TypeError("real_space must be a boolean") + + # # Warn if doing DFT convolution for objects with hard edges + # if not real_space and hard_edge: + # if len(args) == 2: + # galsim_warn( + # "Doing convolution of 2 objects, both with hard edges. " + # "This might be more accurate with `real_space=True`, " + # "but this functionality has not yet been implemented in JAX-Galsim." + # ) + # else: + # galsim_warn( + # "Doing convolution where all objects have hard edges. " + # "There might be some inaccuracies due to ringing in k-space." + # ) + # if real_space: + # # Can't do real space if nobj > 2 + # if len(args) > 2: + # galsim_warn( + # "Real-space convolution of more than 2 objects is not implemented. " + # "Switching to DFT method." + # ) + # real_space = False + + # # Also can't do real space if any object is not analytic, so check for that. + # else: + # for obj in args: + # if not obj.is_analytic_x: + # galsim_warn( + # "A component to be convolved is not analytic in real space. " + # "Cannot use real space convolution. Switching to DFT method." + # ) + # real_space = False + # break + # MRB: end of commented out code blocks # Save the construction parameters (as they are at this point) as attributes so they # can be inspected later if necessary. diff --git a/tests/jax/test_derivs_params.py b/tests/jax/test_derivs_params.py index 83438bc7..49afc171 100644 --- a/tests/jax/test_derivs_params.py +++ b/tests/jax/test_derivs_params.py @@ -30,7 +30,7 @@ def _run(val_): **kwargs, gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), ) - .drawImage(nx=48, ny=48, scale=0.2) + .drawImage(nx=48, ny=48, scale=0.2, method="fft") .array[24, 24] ** 2 ) @@ -41,3 +41,29 @@ def _run(val_): gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) + + +def test_deriv_params_moffat_trunc(): + val = 2.0 + eps = 1e-5 + + def _run(val_): + return jnp.max( + jgs.Moffat( + 2.5, + scale_radius=val_, + trunc=20.0, + gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), + ) + .drawImage(nx=48, ny=48, scale=0.2) + .array[24, 24] + ** 2 + ) + + gfunc = jax.jit(jax.grad(_run)) + with jax.disable_jit(), jax.debug_nans(): + gval = gfunc(val) + + gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps + + np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) From ecba9da3ca47e3489e0dd246b67bd25e7e2d0c1a Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Feb 2026 15:22:16 -0600 Subject: [PATCH 18/18] fix: performance regression for moffat init and test derivs with trunc --- jax_galsim/moffat.py | 7 +++---- tests/jax/test_derivs_params.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 8384dea6..cd3d3593 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -107,12 +107,11 @@ def __init__( super().__init__( beta=beta, scale_radius=( - jax.lax.cond( + jnp.where( trunc > 0, - lambda x: _MoffatCalculateSRFromHLR(x, trunc_, beta), - lambda x: x + _MoffatCalculateSRFromHLR(half_light_radius, trunc_, beta), + half_light_radius / jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0), - half_light_radius, ) ), trunc=trunc, diff --git a/tests/jax/test_derivs_params.py b/tests/jax/test_derivs_params.py index 49afc171..ffb33fa1 100644 --- a/tests/jax/test_derivs_params.py +++ b/tests/jax/test_derivs_params.py @@ -43,7 +43,7 @@ def _run(val_): np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) -def test_deriv_params_moffat_trunc(): +def test_deriv_params_moffat_with_trunc(): val = 2.0 eps = 1e-5 @@ -51,7 +51,7 @@ def _run(val_): return jnp.max( jgs.Moffat( 2.5, - scale_radius=val_, + half_light_radius=val_, trunc=20.0, gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), ) @@ -67,3 +67,29 @@ def _run(val_): gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) + + +def test_deriv_params_moffat_with_respect_to_trunc(): + val = 20.0 + eps = 1e-5 + + def _run(val_): + return jnp.max( + jgs.Moffat( + 2.5, + half_light_radius=2.0, + trunc=val_, + gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), + ) + .drawImage(nx=48, ny=48, scale=0.2) + .array[24, 24] + ** 2 + ) + + gfunc = jax.jit(jax.grad(_run)) + with jax.disable_jit(), jax.debug_nans(): + gval = gfunc(val) + + gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps + + np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6)