diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index ed513315..fad56976 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -89,7 +89,7 @@ def __repr__(self): elif self == arcsec: return "galsim.arcsec" else: - return "galsim.AngleUnit(%r)" % ensure_hashable(self.value) + return "galsim.AngleUnit(%r)" % (ensure_hashable(self.value),) def __eq__(self, other): return isinstance(other, AngleUnit) and jnp.array_equal(self.value, other.value) @@ -222,7 +222,7 @@ def __str__(self): return str(ensure_hashable(self._rad)) + " radians" def __repr__(self): - return "galsim.Angle(%r, galsim.radians)" % ensure_hashable(self.rad) + return "galsim.Angle(%r, galsim.radians)" % (ensure_hashable(self.rad),) def __eq__(self, other): return isinstance(other, Angle) and jnp.array_equal(self.rad, other.rad) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index ed5942af..e4103493 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,11 +1,8 @@ import galsim as _galsim -import jax -import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( - cast_to_float, - cast_to_int, ensure_hashable, implements, ) @@ -84,13 +81,9 @@ def _parse_args(self, *args, **kwargs): if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) - # for simple inputs, we can check if the bounds are valid - if ( - isinstance(self.xmin, (float, int)) - and isinstance(self.xmax, (float, int)) - and isinstance(self.ymin, (float, int)) - and isinstance(self.ymax, (float, int)) - and ((self.xmin > self.xmax) or (self.ymin > self.ymax)) + if not ( + float(self.xmin) <= float(self.xmax) + and float(self.ymin) <= float(self.ymax) ): self._isdefined = False @@ -144,8 +137,8 @@ def expand(self, factor_x, factor_y=None): dx = (self.xmax - self.xmin) * 0.5 * (factor_x - 1.0) dy = (self.ymax - self.ymin) * 0.5 * (factor_y - 1.0) if isinstance(self, BoundsI): - dx = jnp.ceil(dx) - dy = jnp.ceil(dy) + dx = np.ceil(dx) + dy = np.ceil(dy) return self.withBorder(dx, dy) def __and__(self, other): @@ -154,11 +147,11 @@ def __and__(self, other): if not self.isDefined() or not other.isDefined(): return self.__class__() else: - xmin = jnp.maximum(self.xmin, other.xmin) - xmax = jnp.minimum(self.xmax, other.xmax) - ymin = jnp.maximum(self.ymin, other.ymin) - ymax = jnp.minimum(self.ymax, other.ymax) - if xmin > xmax or ymin > ymax: + xmin = np.maximum(self.xmin, other.xmin) + xmax = np.minimum(self.xmax, other.xmax) + ymin = np.maximum(self.ymin, other.ymin) + ymax = np.minimum(self.ymax, other.ymax) + if (xmin > xmax) or (ymin > ymax): return self.__class__() else: return self.__class__(xmin, xmax, ymin, ymax) @@ -168,19 +161,19 @@ def __add__(self, other): if not other.isDefined(): return self elif self.isDefined(): - xmin = jnp.minimum(self.xmin, other.xmin) - xmax = jnp.maximum(self.xmax, other.xmax) - ymin = jnp.minimum(self.ymin, other.ymin) - ymax = jnp.maximum(self.ymax, other.ymax) + xmin = np.minimum(self.xmin, other.xmin) + xmax = np.maximum(self.xmax, other.xmax) + ymin = np.minimum(self.ymin, other.ymin) + ymax = np.maximum(self.ymax, other.ymax) return self.__class__(xmin, xmax, ymin, ymax) else: return other elif isinstance(other, self._pos_class): if self.isDefined(): - xmin = jnp.minimum(self.xmin, other.x) - xmax = jnp.maximum(self.xmax, other.x) - ymin = jnp.minimum(self.ymin, other.y) - ymax = jnp.maximum(self.ymax, other.y) + xmin = np.minimum(self.xmin, other.x) + xmax = np.maximum(self.xmax, other.x) + ymin = np.minimum(self.ymin, other.y) + ymax = np.maximum(self.ymax, other.y) return self.__class__(xmin, xmax, ymin, ymax) else: return self.__class__(other) @@ -229,18 +222,18 @@ def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing + children = () + # Define auxiliary static data that doesn’t need to be traced if self.isDefined(): - children = (self.xmin, self.xmax, self.ymin, self.ymax) + aux_data = (self.xmin, self.xmax, self.ymin, self.ymax) else: - children = tuple() - # Define auxiliary static data that doesn’t need to be traced - aux_data = None + aux_data = () return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - return cls(*children) + return cls(*aux_data) @classmethod def from_galsim(cls, galsim_bounds): @@ -291,15 +284,37 @@ class BoundsD(Bounds): def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - self.xmin = cast_to_float(self.xmin) - self.xmax = cast_to_float(self.xmax) - self.ymin = cast_to_float(self.ymin) - self.ymax = cast_to_float(self.ymax) + self.xmin = float(self.xmin) + self.xmax = float(self.xmax) + self.ymin = float(self.ymin) + self.ymax = float(self.ymax) + + if not ( + isinstance( + self.xmin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.xmax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + ): + raise ValueError( + "BoundsI/D classes must use python ints/floats or numpy values in JAX-GalSim!" + ) def _check_scalar(self, x, name): try: if ( - isinstance(x, jax.Array) + isinstance(x, np.ndarray) and x.shape == () and x.dtype.name in ["float32", "float64", "float"] ): @@ -325,30 +340,46 @@ class BoundsI(Bounds): def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - # for simple inputs, we can check if the bounds are valid ints + if ( - isinstance(self.xmin, (float, int)) - and isinstance(self.xmax, (float, int)) - and isinstance(self.ymin, (float, int)) - and isinstance(self.ymax, (float, int)) - and ( - self.xmin != int(self.xmin) - or self.xmax != int(self.xmax) - or self.ymin != int(self.ymin) - or self.ymax != int(self.ymax) - ) + self.xmin != int(self.xmin) + or self.xmax != int(self.xmax) + or self.ymin != int(self.ymin) + or self.ymax != int(self.ymax) ): raise TypeError("BoundsI must be initialized with integer values") - self.xmin = cast_to_int(self.xmin) - self.xmax = cast_to_int(self.xmax) - self.ymin = cast_to_int(self.ymin) - self.ymax = cast_to_int(self.ymax) + self.xmin = int(self.xmin) + self.xmax = int(self.xmax) + self.ymin = int(self.ymin) + self.ymax = int(self.ymax) + + if not ( + isinstance( + self.xmin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.xmax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + ): + raise ValueError( + "BoundsI/D classes must use python ints/floats or numpy values in JAX-GalSim!" + ) def _check_scalar(self, x, name): try: if ( - isinstance(x, jax.Array) + isinstance(x, np.ndarray) and x.shape == () and x.dtype.name in ["int32", "int64", "int"] ): diff --git a/jax_galsim/box.py b/jax_galsim/box.py index 95b3d373..6a6ccde9 100644 --- a/jax_galsim/box.py +++ b/jax_galsim/box.py @@ -62,7 +62,7 @@ def __str__(self): ensure_hashable(self.height), ) if self.flux != 1.0: - s += ", flux=%s" % ensure_hashable(self.flux) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s @@ -146,9 +146,9 @@ def __repr__(self): ) def __str__(self): - s = "galsim.Pixel(scale=%s" % ensure_hashable(self.scale) + s = "galsim.Pixel(scale=%s" % (ensure_hashable(self.scale),) if self.flux != 1.0: - s += ", flux=%s" % ensure_hashable(self.flux) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index c4cc5413..4c9a2091 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -97,7 +97,7 @@ def cast_to_float(x): return float(x) except Exception: try: - return jnp.asarray(x, dtype=float) + return jnp.astype(x, dtype=float) except Exception: # this will return the same value for anything float-like that # cannot be cast to float @@ -115,20 +115,12 @@ def cast_to_int(x): return int(x) except Exception: try: - if not jnp.any(jnp.isnan(x)): - return jnp.asarray(x, dtype=int) - else: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - if type(x) is object: - return x - else: - return 1 * x + return jnp.astype(x, dtype=int) except Exception: # this will return the same value for anything int-like that # cannot be cast to int # however, it will raise an error if something is not int-like + # we exclude object types since they are used in JAX tracing if type(x) is object: return x else: diff --git a/jax_galsim/exponential.py b/jax_galsim/exponential.py index cbb716ad..1d472544 100644 --- a/jax_galsim/exponential.py +++ b/jax_galsim/exponential.py @@ -92,8 +92,8 @@ def __repr__(self): ) def __str__(self): - s = "galsim.Exponential(scale_radius=%s" % ensure_hashable(self.scale_radius) - s += ", flux=%s" % ensure_hashable(self.flux) + s = "galsim.Exponential(scale_radius=%s" % (ensure_hashable(self.scale_radius),) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/jax_galsim/gaussian.py b/jax_galsim/gaussian.py index 3b937a11..f9c1616d 100644 --- a/jax_galsim/gaussian.py +++ b/jax_galsim/gaussian.py @@ -107,8 +107,8 @@ def __repr__(self): ) def __str__(self): - s = "galsim.Gaussian(sigma=%s" % ensure_hashable(self.sigma) - s += ", flux=%s" % ensure_hashable(self.flux) + s = "galsim.Gaussian(sigma=%s" % (ensure_hashable(self.sigma),) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 22d4e16e..2a9b312b 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -232,9 +232,9 @@ def __str__(self): ensure_hashable(self.scale_radius), ) if self.trunc != 0.0: - s += ", trunc=%s" % ensure_hashable(self.trunc) + s += ", trunc=%s" % (ensure_hashable(self.trunc),) if self.flux != 1.0: - s += ", flux=%s" % ensure_hashable(self.flux) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index b4b17904..fbc756db 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -387,7 +387,7 @@ def __str__(self): ensure_hashable(self.half_light_radius), ) if self.flux != 1.0: - s += ", flux=%s" % ensure_hashable(self.flux) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index d7d23838..bf9f4f6d 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -232,7 +232,7 @@ def __str__(self): ensure_hashable(self._offset.y), ) if self._flux_ratio != 1.0: - s += " * %s" % ensure_hashable(self._flux_ratio) + s += " * %s" % (ensure_hashable(self._flux_ratio),) return s @property diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index fcddf18e..ee18ef33 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -919,7 +919,7 @@ def __eq__(self, other): ) def __repr__(self): - return "galsim.PixelScale(%r)" % ensure_hashable(self.scale) + return "galsim.PixelScale(%r)" % (ensure_hashable(self.scale),) def __hash__(self): return hash(repr(self)) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index e79f320c..e8d4b0d1 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -498,6 +498,8 @@ def _reg_sfun(g1): jnp.array(0.2), jnp.array(4.0), jnp.array(-0.5), jnp.array(4.7) ), jax_galsim.BoundsI(jnp.array(-10), jnp.array(5), jnp.array(0), jnp.array(7)), + jax_galsim.BoundsD(0.2, 4.0, -0.5, 4.7), + jax_galsim.BoundsI(-10, 5, 0, 7), ], ) def test_api_bounds(obj): @@ -505,43 +507,11 @@ def test_api_bounds(obj): _run_object_checks(obj, obj.__class__, "pickle-eval-repr") _run_object_checks(obj, obj.__class__, "to-from-galsim") + assert isinstance(obj.xmin, (float, int)) + # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj - if isinstance(obj, jax_galsim.BoundsD): - - def _reg_sfun(g1): - return ( - ( - obj.__class__(g1, g1 + 0.5, 2 * g1, 2 * g1 + 0.5).expand(0.5) - + obj.__class__(-g1, -g1 + 0.5, -2 * g1, -2 * g1 + 0.5) - ) - .expand(4) - .area() - ) - - _sfun = jax.jit(_reg_sfun) - - _sgradfun = jax.jit(jax.grad(_sfun)) - _sfun_vmap = jax.jit(jax.vmap(_sfun)) - _sgradfun_vmap = jax.jit(jax.vmap(_sgradfun)) - - # we can jit the object - np.testing.assert_allclose(_sfun(0.3), _reg_sfun(0.3)) - - # check derivs - eps = 1e-6 - grad = _sgradfun(0.3) - finite_diff = (_reg_sfun(0.3 + eps) - _reg_sfun(0.3 - eps)) / (2 * eps) - np.testing.assert_allclose(grad, finite_diff) - - # check vmap - x = jnp.linspace(-0.9, 0.9, 10) - np.testing.assert_allclose(_sfun_vmap(x), [_reg_sfun(_x) for _x in x]) - - # check vmap grad - np.testing.assert_allclose(_sgradfun_vmap(x), [_sgradfun(_x) for _x in x]) - @pytest.mark.parametrize( "obj", diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py new file mode 100644 index 00000000..a3c832a0 --- /dev/null +++ b/tests/jax/test_render_scene.py @@ -0,0 +1,374 @@ +from functools import partial + +import galsim as _galsim +import jax +import jax.numpy as jnp +import jax.random as jrng +import numpy as np +import pytest + +import jax_galsim as jgs + + +def _generate_image_one(rng_key, psf): + rng_key, use_key = jrng.split(rng_key) + flux = jrng.uniform(use_key, minval=1.5, maxval=2.5) + rng_key, use_key = jrng.split(rng_key) + hlr = jrng.uniform(use_key, minval=0.5, maxval=2.5) + rng_key, use_key = jrng.split(rng_key) + g1 = jrng.uniform(use_key, minval=-0.1, maxval=0.1) + rng_key, use_key = jrng.split(rng_key) + g2 = jrng.uniform(use_key, minval=-0.1, maxval=0.1) + + rng_key, use_key = jrng.split(rng_key) + dx = jrng.uniform(use_key, minval=-10, maxval=10) + rng_key, use_key = jrng.split(rng_key) + dy = jrng.uniform(use_key, minval=-10, maxval=10) + + return ( + jgs.Convolve( + [ + jgs.Exponential(half_light_radius=hlr) + .shear(g1=g1, g2=g2) + .shift(dx, dy) + .withFlux(flux), + psf, + ] + ) + .withGSParams(minimum_fft_size=1024, maximum_fft_size=1024) + .drawImage(nx=200, ny=200, scale=0.2) + ) + + +@partial(jax.jit, static_argnames=("n_obj")) +def _generate_image(rng_key, psf, n_obj): + use_keys = jrng.split(rng_key, num=n_obj + 1) + rng_key = use_keys[0] + use_keys = use_keys[1:] + + return jax.vmap(_generate_image_one, in_axes=(0, None))(use_keys, psf) + + +def test_render_scene_draw_many_ffts_full_img(): + psf = jgs.Gaussian(fwhm=0.9) + img = _generate_image(jrng.key(10), psf, 5) + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(img.array.sum(axis=0)) + pdb.set_trace() + + assert img.array.shape == (5, 200, 200) + assert img.array.sum() > 5.0 + + +def _get_bd_jgs( + flux_d, + flux_b, + hlr_b, + hlr_d, + q_b, + q_d, + beta, + *, + psf_hlr=0.7, +): + components = [] + + # disk + disk = jgs.Exponential(flux=flux_d, half_light_radius=hlr_d).shear( + q=q_d, beta=beta * jgs.degrees + ) + components.append(disk) + + # bulge + bulge = jgs.Spergel(nu=-0.6, flux=flux_b, half_light_radius=hlr_b).shear( + q=q_b, beta=beta * jgs.degrees + ) + components.append(bulge) + + galaxy = jgs.Add(components) + + # psf + psf = jgs.Moffat(2, flux=1.0, half_light_radius=0.7) + + gal_conv = jgs.Convolve([galaxy, psf]) + return gal_conv + + +@partial(jax.jit, static_argnames=("fft_size", "slen")) +def _draw_stamp_jgs( + galaxy_params: dict, + image_pos: jgs.PositionD, + local_wcs: jgs.PixelScale, + fft_size: int, + slen: int, +) -> jax.Array: + gsparams = jgs.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + + convolved_object = _get_bd_jgs(**galaxy_params).withGSParams(gsparams) + + # you have to render just with on offset in order to keep the bounds + # static during rendering + # here dx,dy is the offset to the nearest pixel + # we then render with use_true_center = False to ensure the offset is + # applied relative to a pixel center for all image dimensions, including + # even ones. + # this means the object is offset by (dx,dy) from stamp.bounds.center + dx = image_pos.x - jnp.floor(image_pos.x + 0.5) + dy = image_pos.y - jnp.floor(image_pos.y + 0.5) + + stamp = convolved_object.drawImage( + nx=slen, + ny=slen, + offset=(dx, dy), + wcs=local_wcs, + dtype=jnp.float64, + use_true_center=False, + ) + + return stamp + + +@partial(jax.jit, static_argnames=("slen",)) +def _add_to_image(carry, x, slen): + image = carry[0] + stamp, image_pos = x + + # then we apply a shift to the stamp get the correct final location + # above we rendered at the location xs, ys = (dx,dy) + stamp.bounds.center + # in the image.bounds coordinates, the location (xs,ys) should be + # + # (xs - stamp.bounds.xmin) + shift.x = image_pos.x - image.bounds.xmin + # + # the logic here is that the offset of the object in array indices in the final + # image should be equal to the shift in array indices of the stamo plus the offset + # in array indicies of the stamp. + # we then get for x + # shift.x = image_pos.x - image.bounds.xmin - xs + stamp.bounds.xmin + # = image_pos.x - dx - stamp.bounds.center.x + stamp.bounds.xmin - image.bounds.xmin + # = image_pos.x - (image_pos.x - jnp.floor(image_pos.x + 0.5)) - stamp.bounds.center.x + stamp.bounds.xmin - image.bounds.xmin + # = jnp.floor(image_pos.x + 0.5) - stamp.bounds.center.x + stamp.bounds.xmin - image.bounds.xmin + shift = jgs.PositionI( + jnp.int32( + jnp.floor(image_pos.x + 0.5) + - stamp.bounds.center.x + + stamp.bounds.xmin + - image.bounds.xmin + ), + jnp.int32( + jnp.floor(image_pos.y + 0.5) + - stamp.bounds.center.y + + stamp.bounds.ymin + - image.bounds.ymin + ), + ) + + start_inds = (shift.y, shift.x) + subim = jax.lax.dynamic_slice(image.array, start_inds, (slen, slen)) + subim = subim + stamp.array + + image._array = jax.lax.dynamic_update_slice( + image.array, + subim, + start_inds, + ) + + return (image,), None + + +def _get_bd_gs( + flux_d, + flux_b, + hlr_b, + hlr_d, + q_b, + q_d, + beta, + *, + psf_hlr=0.7, +): + components = [] + + # disk + disk = _galsim.Exponential(flux=flux_d, half_light_radius=hlr_d).shear( + q=q_d, beta=beta * _galsim.degrees + ) + components.append(disk) + + # bulge + bulge = _galsim.Spergel(nu=-0.6, flux=flux_b, half_light_radius=hlr_b).shear( + q=q_b, beta=beta * _galsim.degrees + ) + components.append(bulge) + + galaxy = _galsim.Add(components) + + # psf + psf = _galsim.Moffat(2, flux=1.0, half_light_radius=0.7) + + gal_conv = _galsim.Convolve([galaxy, psf]) + return gal_conv + + +def _render_scene_stamps_galsim( + galaxy_params: dict, + image_pos: list[_galsim.PositionD], + local_wcs: list[_galsim.PixelScale], + fft_size: int, + slen: int, + image: _galsim.ImageD, + ng: int, +): + gsparams = _galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + + for i in range(ng): + gpars = {k: v[i] for k, v in galaxy_params.items()} + convolved_object = _get_bd_gs(**gpars).withGSParams(gsparams) + + stamp = convolved_object.drawImage( + nx=slen, + ny=slen, + center=(image_pos[i].x, image_pos[i].y), + wcs=local_wcs[i], + dtype=np.float64, + ) + + b = stamp.bounds & image.bounds + if b.isDefined(): + image[b] += stamp[b] + + return image + + +@pytest.mark.parametrize("slen", [51, 52]) +def test_render_scene_stamps(slen): + image = jgs.Image(ncol=200, nrow=200, scale=0.2, dtype=jnp.float64) + wcs = image.wcs + + rng = np.random.default_rng(seed=10) + ng = 5 + fft_size = 2048 + + galaxy_params = { + "flux_d": rng.uniform(low=0, high=1.0, size=ng), + "flux_b": rng.uniform(low=0, high=1.0, size=ng), + "hlr_b": rng.uniform(low=0.3, high=0.5, size=ng), + "hlr_d": rng.uniform(low=0.5, high=0.7, size=ng), + "q_b": rng.uniform(low=0.1, high=0.9, size=ng), + "q_d": rng.uniform(low=0.1, high=0.9, size=ng), + "beta": rng.uniform(low=0, high=360, size=ng), + "x": rng.uniform(low=10, high=190, size=ng), + "y": rng.uniform(low=10, high=190, size=ng), + } + + x = galaxy_params.pop("x") + y = galaxy_params.pop("y") + image_positions = jax.vmap(lambda x, y: jgs.PositionD(x=x, y=y))(x, y) + local_wcss = jax.vmap(lambda x: wcs.local(image_pos=x))(image_positions) + + stamps = jax.jit(jax.vmap(partial(_draw_stamp_jgs, slen=slen, fft_size=fft_size)))( + galaxy_params, image_positions, local_wcss + ) + assert stamps.array.shape == (ng, slen, slen) + assert stamps.array.sum() > 0 + + pad_image = jgs.ImageD( + jnp.pad(image.array, slen), wcs=image.wcs, bounds=image.bounds.withBorder(slen) + ) + + final_pad_image = jax.lax.scan( + partial(_add_to_image, slen=slen), + (pad_image,), + xs=(stamps, image_positions), + length=ng, + )[0][0] + np.testing.assert_allclose(final_pad_image.array.sum(), stamps.array.sum()) + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(final_pad_image.array) + pdb.set_trace() + + gs_image = _galsim.Image(ncol=200, nrow=200, scale=0.2, dtype=np.float64) + wcs = gs_image.wcs + + gs_image_positions = list( + map(lambda tup: _galsim.PositionD(x=tup[0], y=tup[1]), zip(x, y)) + ) + gs_local_wcss = list(map(lambda x: wcs.local(image_pos=x), gs_image_positions)) + + _render_scene_stamps_galsim( + galaxy_params, + gs_image_positions, + gs_local_wcss, + fft_size, + slen, + gs_image, + ng, + ) + + gs_image_mo = _galsim.Image(ncol=200, nrow=200, scale=0.2, dtype=np.float64) + wcs = gs_image.wcs + + gs_image_positions = list( + map(lambda tup: _galsim.PositionD(x=tup[0], y=tup[1]), zip(x, y)) + ) + gs_local_wcss = list(map(lambda x: wcs.local(image_pos=x), gs_image_positions)) + + _render_scene_stamps_galsim( + galaxy_params, + gs_image_positions, + gs_local_wcss, + fft_size, + slen + 1, + gs_image_mo, + ng, + ) + + abs_eps = 4.0 * np.max(np.abs(gs_image_mo.array - gs_image.array)) + rel_eps = 0.0 + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(gs_image_mo.array - gs_image.array) + pdb.set_trace() + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(gs_image.array) + pdb.set_trace() + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(final_pad_image.array[slen:-slen, slen:-slen] - gs_image.array) + pdb.set_trace() + + np.testing.assert_allclose( + final_pad_image.array[slen:-slen, slen:-slen].sum(), + gs_image.array.sum(), + atol=abs_eps, + rtol=rel_eps, + ) + + np.testing.assert_allclose( + final_pad_image.array[slen:-slen, slen:-slen], + gs_image.array, + atol=abs_eps, + rtol=rel_eps, + ) diff --git a/tests/jax/test_vmapping.py b/tests/jax/test_vmapping.py index 6b8d7c40..4a3e6b31 100644 --- a/tests/jax/test_vmapping.py +++ b/tests/jax/test_vmapping.py @@ -141,25 +141,6 @@ def test_eq(self, other): assert test_eq(obj_duplicated, obj) -def test_bounds_vmapping(): - obj = galsim.BoundsD(0.0, 1.0, 0.0, 1.0) - obj_d = jax.vmap(galsim.BoundsD)(0.0 * e, 1.0 * e, 0.0 * e, 1.0 * e) - - objI = galsim.BoundsI(0.0, 1.0, 0.0, 1.0) - objI_d = jax.vmap(galsim.BoundsI)(0.0 * e, 1.0 * e, 0.0 * e, 1.0 * e) - - def test_eq(self, other): - return ( - (self.xmin == jnp.array([other.xmin, other.xmin])).all() - and (self.xmax == jnp.array([other.xmax, other.xmax])).all() - and (self.ymin == jnp.array([other.ymin, other.ymin])).all() - and (self.ymax == jnp.array([other.ymax, other.ymax])).all() - ) - - assert test_eq(obj_d, obj) - assert test_eq(objI_d, objI) - - def test_drawing_vmapping_and_jitting_gaussian_psf(): gsparams = galsim.GSParams(minimum_fft_size=512, maximum_fft_size=512)