Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __repr__(self):
elif self == arcsec:
return "galsim.arcsec"
else:
return "galsim.AngleUnit(%r)" % ensure_hashable(self.value)
return "galsim.AngleUnit(%r)" % (ensure_hashable(self.value),)

def __eq__(self, other):
return isinstance(other, AngleUnit) and jnp.array_equal(self.value, other.value)
Expand Down Expand Up @@ -222,7 +222,7 @@ def __str__(self):
return str(ensure_hashable(self._rad)) + " radians"

def __repr__(self):
return "galsim.Angle(%r, galsim.radians)" % ensure_hashable(self.rad)
return "galsim.Angle(%r, galsim.radians)" % (ensure_hashable(self.rad),)

def __eq__(self, other):
return isinstance(other, Angle) and jnp.array_equal(self.rad, other.rad)
Expand Down
135 changes: 83 additions & 52 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import (
cast_to_float,
cast_to_int,
ensure_hashable,
implements,
)
Expand Down Expand Up @@ -84,13 +81,9 @@ def _parse_args(self, *args, **kwargs):
if kwargs:
raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys())

# for simple inputs, we can check if the bounds are valid
if (
isinstance(self.xmin, (float, int))
and isinstance(self.xmax, (float, int))
and isinstance(self.ymin, (float, int))
and isinstance(self.ymax, (float, int))
and ((self.xmin > self.xmax) or (self.ymin > self.ymax))
if not (
float(self.xmin) <= float(self.xmax)
and float(self.ymin) <= float(self.ymax)
):
self._isdefined = False

Expand Down Expand Up @@ -144,8 +137,8 @@ def expand(self, factor_x, factor_y=None):
dx = (self.xmax - self.xmin) * 0.5 * (factor_x - 1.0)
dy = (self.ymax - self.ymin) * 0.5 * (factor_y - 1.0)
if isinstance(self, BoundsI):
dx = jnp.ceil(dx)
dy = jnp.ceil(dy)
dx = np.ceil(dx)
dy = np.ceil(dy)
return self.withBorder(dx, dy)

def __and__(self, other):
Expand All @@ -154,11 +147,11 @@ def __and__(self, other):
if not self.isDefined() or not other.isDefined():
return self.__class__()
else:
xmin = jnp.maximum(self.xmin, other.xmin)
xmax = jnp.minimum(self.xmax, other.xmax)
ymin = jnp.maximum(self.ymin, other.ymin)
ymax = jnp.minimum(self.ymax, other.ymax)
if xmin > xmax or ymin > ymax:
xmin = np.maximum(self.xmin, other.xmin)
xmax = np.minimum(self.xmax, other.xmax)
ymin = np.maximum(self.ymin, other.ymin)
ymax = np.minimum(self.ymax, other.ymax)
if (xmin > xmax) or (ymin > ymax):
return self.__class__()
else:
return self.__class__(xmin, xmax, ymin, ymax)
Expand All @@ -168,19 +161,19 @@ def __add__(self, other):
if not other.isDefined():
return self
elif self.isDefined():
xmin = jnp.minimum(self.xmin, other.xmin)
xmax = jnp.maximum(self.xmax, other.xmax)
ymin = jnp.minimum(self.ymin, other.ymin)
ymax = jnp.maximum(self.ymax, other.ymax)
xmin = np.minimum(self.xmin, other.xmin)
xmax = np.maximum(self.xmax, other.xmax)
ymin = np.minimum(self.ymin, other.ymin)
ymax = np.maximum(self.ymax, other.ymax)
return self.__class__(xmin, xmax, ymin, ymax)
else:
return other
elif isinstance(other, self._pos_class):
if self.isDefined():
xmin = jnp.minimum(self.xmin, other.x)
xmax = jnp.maximum(self.xmax, other.x)
ymin = jnp.minimum(self.ymin, other.y)
ymax = jnp.maximum(self.ymax, other.y)
xmin = np.minimum(self.xmin, other.x)
xmax = np.maximum(self.xmax, other.x)
ymin = np.minimum(self.ymin, other.y)
ymax = np.maximum(self.ymax, other.y)
return self.__class__(xmin, xmax, ymin, ymax)
else:
return self.__class__(other)
Expand Down Expand Up @@ -229,18 +222,18 @@ def tree_flatten(self):
"""This function flattens the Bounds into a list of children
nodes that will be traced by JAX and auxiliary static data."""
# Define the children nodes of the PyTree that need tracing
children = ()
# Define auxiliary static data that doesn’t need to be traced
if self.isDefined():
children = (self.xmin, self.xmax, self.ymin, self.ymax)
aux_data = (self.xmin, self.xmax, self.ymin, self.ymax)
else:
children = tuple()
# Define auxiliary static data that doesn’t need to be traced
aux_data = None
aux_data = ()
return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
"""Recreates an instance of the class from flatten representation"""
return cls(*children)
return cls(*aux_data)

@classmethod
def from_galsim(cls, galsim_bounds):
Expand Down Expand Up @@ -291,15 +284,37 @@ class BoundsD(Bounds):

def __init__(self, *args, **kwargs):
self._parse_args(*args, **kwargs)
self.xmin = cast_to_float(self.xmin)
self.xmax = cast_to_float(self.xmax)
self.ymin = cast_to_float(self.ymin)
self.ymax = cast_to_float(self.ymax)
self.xmin = float(self.xmin)
self.xmax = float(self.xmax)
self.ymin = float(self.ymin)
self.ymax = float(self.ymax)

if not (
isinstance(
self.xmin,
(float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64),
)
and isinstance(
self.xmax,
(float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64),
)
and isinstance(
self.ymin,
(float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64),
)
and isinstance(
self.ymax,
(float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64),
)
):
raise ValueError(
"BoundsI/D classes must use python ints/floats or numpy values in JAX-GalSim!"
)

def _check_scalar(self, x, name):
try:
if (
isinstance(x, jax.Array)
isinstance(x, np.ndarray)
and x.shape == ()
and x.dtype.name in ["float32", "float64", "float"]
):
Expand All @@ -325,30 +340,46 @@ class BoundsI(Bounds):

def __init__(self, *args, **kwargs):
self._parse_args(*args, **kwargs)
# for simple inputs, we can check if the bounds are valid ints

if (
isinstance(self.xmin, (float, int))
and isinstance(self.xmax, (float, int))
and isinstance(self.ymin, (float, int))
and isinstance(self.ymax, (float, int))
and (
self.xmin != int(self.xmin)
or self.xmax != int(self.xmax)
or self.ymin != int(self.ymin)
or self.ymax != int(self.ymax)
)
self.xmin != int(self.xmin)
or self.xmax != int(self.xmax)
or self.ymin != int(self.ymin)
or self.ymax != int(self.ymax)
):
raise TypeError("BoundsI must be initialized with integer values")

self.xmin = cast_to_int(self.xmin)
self.xmax = cast_to_int(self.xmax)
self.ymin = cast_to_int(self.ymin)
self.ymax = cast_to_int(self.ymax)
self.xmin = int(self.xmin)
self.xmax = int(self.xmax)
self.ymin = int(self.ymin)
self.ymax = int(self.ymax)

if not (
isinstance(
self.xmin,
(float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64),
)
and isinstance(
self.xmax,
(float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64),
)
and isinstance(
self.ymin,
(float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64),
)
and isinstance(
self.ymax,
(float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64),
)
):
raise ValueError(
"BoundsI/D classes must use python ints/floats or numpy values in JAX-GalSim!"
)

def _check_scalar(self, x, name):
try:
if (
isinstance(x, jax.Array)
isinstance(x, np.ndarray)
and x.shape == ()
and x.dtype.name in ["int32", "int64", "int"]
):
Expand Down
6 changes: 3 additions & 3 deletions jax_galsim/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __str__(self):
ensure_hashable(self.height),
)
if self.flux != 1.0:
s += ", flux=%s" % ensure_hashable(self.flux)
s += ", flux=%s" % (ensure_hashable(self.flux),)
s += ")"
return s

Expand Down Expand Up @@ -146,9 +146,9 @@ def __repr__(self):
)

def __str__(self):
s = "galsim.Pixel(scale=%s" % ensure_hashable(self.scale)
s = "galsim.Pixel(scale=%s" % (ensure_hashable(self.scale),)
if self.flux != 1.0:
s += ", flux=%s" % ensure_hashable(self.flux)
s += ", flux=%s" % (ensure_hashable(self.flux),)
s += ")"
return s

Expand Down
14 changes: 3 additions & 11 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def cast_to_float(x):
return float(x)
except Exception:
try:
return jnp.asarray(x, dtype=float)
return jnp.astype(x, dtype=float)
except Exception:
# this will return the same value for anything float-like that
# cannot be cast to float
Expand All @@ -115,20 +115,12 @@ def cast_to_int(x):
return int(x)
except Exception:
try:
if not jnp.any(jnp.isnan(x)):
return jnp.asarray(x, dtype=int)
else:
# this will return the same value for anything int-like that
# cannot be cast to int
# however, it will raise an error if something is not int-like
if type(x) is object:
return x
else:
return 1 * x
return jnp.astype(x, dtype=int)
except Exception:
# this will return the same value for anything int-like that
# cannot be cast to int
# however, it will raise an error if something is not int-like
# we exclude object types since they are used in JAX tracing
if type(x) is object:
return x
else:
Expand Down
4 changes: 2 additions & 2 deletions jax_galsim/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def __repr__(self):
)

def __str__(self):
s = "galsim.Exponential(scale_radius=%s" % ensure_hashable(self.scale_radius)
s += ", flux=%s" % ensure_hashable(self.flux)
s = "galsim.Exponential(scale_radius=%s" % (ensure_hashable(self.scale_radius),)
s += ", flux=%s" % (ensure_hashable(self.flux),)
s += ")"
return s

Expand Down
4 changes: 2 additions & 2 deletions jax_galsim/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def __repr__(self):
)

def __str__(self):
s = "galsim.Gaussian(sigma=%s" % ensure_hashable(self.sigma)
s += ", flux=%s" % ensure_hashable(self.flux)
s = "galsim.Gaussian(sigma=%s" % (ensure_hashable(self.sigma),)
s += ", flux=%s" % (ensure_hashable(self.flux),)
s += ")"
return s

Expand Down
4 changes: 2 additions & 2 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def __str__(self):
ensure_hashable(self.scale_radius),
)
if self.trunc != 0.0:
s += ", trunc=%s" % ensure_hashable(self.trunc)
s += ", trunc=%s" % (ensure_hashable(self.trunc),)
if self.flux != 1.0:
s += ", flux=%s" % ensure_hashable(self.flux)
s += ", flux=%s" % (ensure_hashable(self.flux),)
s += ")"
return s

Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/spergel.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def __str__(self):
ensure_hashable(self.half_light_radius),
)
if self.flux != 1.0:
s += ", flux=%s" % ensure_hashable(self.flux)
s += ", flux=%s" % (ensure_hashable(self.flux),)
s += ")"
return s

Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def __str__(self):
ensure_hashable(self._offset.y),
)
if self._flux_ratio != 1.0:
s += " * %s" % ensure_hashable(self._flux_ratio)
s += " * %s" % (ensure_hashable(self._flux_ratio),)
return s

@property
Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ def __eq__(self, other):
)

def __repr__(self):
return "galsim.PixelScale(%r)" % ensure_hashable(self.scale)
return "galsim.PixelScale(%r)" % (ensure_hashable(self.scale),)

def __hash__(self):
return hash(repr(self))
Expand Down
Loading
Loading