Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions jax_galsim/bessel.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,9 @@ def _temme_series_kve(v, z):
z_sq = z * z
logzo2 = jnp.log(z / 2.0)
mu = -v * logzo2
sinc_v = jnp.where(v == 0.0, 1.0, jnp.sin(jnp.pi * v) / (jnp.pi * v))
sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu) / mu)
sinc_v = jnp.sinc(v)
mu_safe = jnp.where(mu != 0, mu, 1.0)
sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu_safe) / mu_safe)

initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v
initial_p = 0.5 * jnp.exp(mu) / gamma1pv_inv
Expand Down Expand Up @@ -711,6 +712,7 @@ def t3(x): # x>8
return factor * (rc * (cx + sx) - y * rs * (sx - cx))

x = jnp.abs(x)
x_ = jnp.where(x != 0, x, 1.0)
return jnp.select(
[x == 0, x <= 4, x <= 8, x > 8], [1, t1(x), t2(x), t3(x)], default=x
[x == 0, x <= 4, x <= 8, x > 8], [1, t1(x_), t2(x_), t3(x_)], default=x
).reshape(orig_shape)
103 changes: 53 additions & 50 deletions jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,58 +63,61 @@ def __init__(self, *args, **kwargs):
% kwargs.keys()
)

# MRB: we donot run these code blocks since we do not support
# real-space convolutions and they break tracing
# Check whether to perform real space convolution...
# Start by checking if all objects have a hard edge.
hard_edge = True
for obj in args:
if not isinstance(obj, GSObject):
raise TypeError(
"Arguments to Convolution must be GSObjects, not %s" % obj
)
if not obj.has_hard_edges:
hard_edge = False

if real_space is None:
# The automatic determination is to use real_space if 2 items, both with hard edges.
if len(args) <= 2:
real_space = hard_edge
else:
real_space = False
elif bool(real_space) != real_space:
raise TypeError("real_space must be a boolean")

# Warn if doing DFT convolution for objects with hard edges
if not real_space and hard_edge:
if len(args) == 2:
galsim_warn(
"Doing convolution of 2 objects, both with hard edges. "
"This might be more accurate with `real_space=True`, "
"but this functionality has not yet been implemented in JAX-Galsim."
)
else:
galsim_warn(
"Doing convolution where all objects have hard edges. "
"There might be some inaccuracies due to ringing in k-space."
)
if real_space:
# Can't do real space if nobj > 2
if len(args) > 2:
galsim_warn(
"Real-space convolution of more than 2 objects is not implemented. "
"Switching to DFT method."
)
real_space = False

# Also can't do real space if any object is not analytic, so check for that.
else:
for obj in args:
if not obj.is_analytic_x:
galsim_warn(
"A component to be convolved is not analytic in real space. "
"Cannot use real space convolution. Switching to DFT method."
)
real_space = False
break
# hard_edge = True
# for obj in args:
# if not isinstance(obj, GSObject):
# raise TypeError(
# "Arguments to Convolution must be GSObjects, not %s" % obj
# )
# if not obj.has_hard_edges:
# hard_edge = False

# if real_space is None:
# # The automatic determination is to use real_space if 2 items, both with hard edges.
# if len(args) <= 2:
# real_space = hard_edge
# else:
# real_space = False
# elif bool(real_space) != real_space:
# raise TypeError("real_space must be a boolean")

# # Warn if doing DFT convolution for objects with hard edges
# if not real_space and hard_edge:
# if len(args) == 2:
# galsim_warn(
# "Doing convolution of 2 objects, both with hard edges. "
# "This might be more accurate with `real_space=True`, "
# "but this functionality has not yet been implemented in JAX-Galsim."
# )
# else:
# galsim_warn(
# "Doing convolution where all objects have hard edges. "
# "There might be some inaccuracies due to ringing in k-space."
# )
# if real_space:
# # Can't do real space if nobj > 2
# if len(args) > 2:
# galsim_warn(
# "Real-space convolution of more than 2 objects is not implemented. "
# "Switching to DFT method."
# )
# real_space = False

# # Also can't do real space if any object is not analytic, so check for that.
# else:
# for obj in args:
# if not obj.is_analytic_x:
# galsim_warn(
# "A component to be convolved is not analytic in real space. "
# "Cannot use real space convolution. Switching to DFT method."
# )
# real_space = False
# break
# MRB: end of commented out code blocks

# Save the construction parameters (as they are at this point) as attributes so they
# can be inspected later if necessary.
Expand Down
6 changes: 3 additions & 3 deletions jax_galsim/core/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def akima_interp(x, xp, yp, coeffs, fixed_spacing=False):
The values of the Akima cubic spline at the points x.
"""
xp = jnp.asarray(xp)
# yp = jnp.array(yp) # unused
yp = jnp.asarray(yp)
if fixed_spacing:
dxp = xp[1] - xp[0]
i = jnp.floor((x - xp[0]) / dxp).astype(jnp.int32)
Expand All @@ -160,6 +160,6 @@ def akima_interp(x, xp, yp, coeffs, fixed_spacing=False):
dx3 = dx2 * dx
xval = a[i] + b[i] * dx + c[i] * dx2 + d[i] * dx3

xval = jnp.where(x < xp[0], 0, xval)
xval = jnp.where(x > xp[-1], 0, xval)
xval = jnp.where(x < xp[0], yp[0], xval)
xval = jnp.where(x > xp[-1], yp[-1], xval)
return xval
116 changes: 103 additions & 13 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from jax_galsim.bessel import j0, kv
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral
from jax_galsim.core.interpolate import akima_interp, akima_interp_coeffs
from jax_galsim.core.utils import bisect_for_root, ensure_hashable, implements
from jax_galsim.gsobject import GSObject
from jax_galsim.position import PositionD
from jax_galsim.random import UniformDeviate


Expand Down Expand Up @@ -103,12 +103,13 @@ def __init__(
fwhm=fwhm,
)
else:
trunc_ = jnp.where(trunc > 0, trunc, 50.0)
super().__init__(
beta=beta,
scale_radius=(
jax.lax.select(
jnp.where(
trunc > 0,
_MoffatCalculateSRFromHLR(half_light_radius, trunc, beta),
_MoffatCalculateSRFromHLR(half_light_radius, trunc_, beta),
half_light_radius
/ jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0),
)
Expand Down Expand Up @@ -281,7 +282,19 @@ def _prefactor(self):
@jax.jit
def _maxk_func(self, k):
return (
jnp.abs(self._kValue(PositionD(x=k, y=0)).real / self.flux)
jnp.abs(
self._kValue_func(
self.beta,
jnp.atleast_1d(k),
self._knorm_bis,
self._knorm,
self._prefactor,
self._maxRrD,
self.trunc,
self._r0,
)[0].real
/ self.flux
)
- self.gsparams.maxk_threshold
)

Expand Down Expand Up @@ -336,36 +349,113 @@ def _xValue(self, pos):
rsq > self._maxRrD_sq, 0.0, self._norm * jnp.power(1.0 + rsq, -self.beta)
)

def _kValue_untrunc(self, k):
@staticmethod
@jax.jit
def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm, _r0):
"""Non truncated version of _kValue"""
k_ = jnp.where(k > 0, k * _r0, 1.0)
return jnp.where(
k > 0,
self._knorm_bis * jnp.power(k, self.beta - 1.0) * _Knu(self.beta - 1.0, k),
self._knorm,
_knorm_bis * jnp.power(k_, beta - 1.0) * _Knu(beta - 1.0, k_),
_knorm,
)

def _kValue_trunc(self, k):
@staticmethod
@jax.jit
def _kValue_trunc_func(beta, k, _knorm, _prefactor, _maxRrD, _r0):
"""Truncated version of _kValue"""
k_ = k * _r0
k_ = jnp.where(k_ <= 50.0, k_, 50.0)
return jnp.where(
k <= 50.0,
self._knorm * self._prefactor * _hankel(k, self.beta, self._maxRrD),
k_ <= 50.0,
_knorm * _prefactor * _hankel(k_, beta, _maxRrD),
0.0,
)

@staticmethod
@jax.jit
def _kValue_func(beta, k, _knorm_bis, _knorm, _prefactor, _maxRrD, trunc, _r0):
return jax.lax.cond(
trunc > 0,
lambda x: Moffat._kValue_trunc_func(
beta, x, _knorm, _prefactor, _maxRrD, _r0
),
lambda x: Moffat._kValue_untrunc_func(beta, x, _knorm_bis, _knorm, _r0),
k,
)

@staticmethod
@jax.jit
def _kValue_untrunc_asymp_func(beta, k, _knorm_bis, _r0):
kr0 = k * _r0
return (
_knorm_bis
* jnp.power(kr0, beta - 1.0)
* jnp.exp(-kr0)
* jnp.sqrt(jnp.pi / 2 / kr0)
)

def _kValue_untrunc_interp_coeffs(self):
# this number of points gets the tests to pass
# I did not investigate further.
n_pts = 2000
k_min = 0
# this is a fudge factor to help numerical convergnce in the tests
# it should not be needed in principle since the profile is not
# evaluated above maxk, but it appears to be needed anyway and
# IDK why
k_max = self._maxk * 2
k = jnp.linspace(k_min, k_max, n_pts)
vals = self._kValue_untrunc_func(
self.beta,
k,
self._knorm_bis,
self._knorm,
self._r0,
)

# slope to match the interpolant onto an asymptotic expansion of kv
# that is kv(x) ~ sqrt(pi/2/x) * exp(-x) * (1 + slp/x)
aval = self._kValue_untrunc_asymp_func(
self.beta, k[-1], self._knorm_bis, self._r0
)
slp = (vals[-1] / aval - 1) * k[-1] * self._r0

return k, vals, akima_interp_coeffs(k, vals), slp

@jax.jit
def _kValue(self, kpos):
"""computation of the Moffat response in k-space with switch of truncated/untracated case
kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image)
"""
k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq)
k = jnp.sqrt((kpos.x**2 + kpos.y**2))
out_shape = jnp.shape(k)
k = jnp.atleast_1d(k)

# for untruncated profiles, we interpolate and use and asymptotic
# expansion for extrapolation
def _run_untrunc(krun):
k_, vals_, coeffs, slp = self._kValue_untrunc_interp_coeffs()
res = akima_interp(krun, k_, vals_, coeffs, fixed_spacing=True)
krun = jnp.where(krun > 0, krun, k_[1])
return jnp.where(
krun > k_[-1],
self._kValue_untrunc_asymp_func(
self.beta, krun, self._knorm_bis, self._r0
)
* (1.0 + slp / krun / self._r0),
res,
)

res = jax.lax.cond(
self.trunc > 0,
lambda x: self._kValue_trunc(x),
lambda x: self._kValue_untrunc(x),
lambda x: Moffat._kValue_trunc_func(
self.beta, x, self._knorm, self._prefactor, self._maxRrD, self._r0
),
lambda x: _run_untrunc(x),
k,
)

return res.reshape(out_shape)

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
Expand Down
2 changes: 1 addition & 1 deletion tests/GalSim
Submodule GalSim updated 1 files
+1 −4 tests/test_moffat.py
Loading
Loading