Problem
At the moment, https://github.com/electronmicroscopy/quantem/blob/dev/src/quantem/imaging/drift.py hardcodes two images when computing the affine/translation loss:
```python
im0, w0 = self.interpolator[0].warp_image(
    self.images[0].array,
    knot_0,
)
im1, w1 = self.interpolator[1].warp_image(
    self.images[1].array,
    knot_1,
)
# Cross correlation alignment
shifts, image_shift = cross_correlation_shift(
    im0,
    im1,
    ...
)
cost[a0] = np.mean(np.abs(im0 - image_shift))
```
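For reference, the current two-image cost can be reproduced in isolation as below. This is a minimal sketch, not quantem code: `xcorr_align` and `two_image_cost` are hypothetical names, and `xcorr_align` is an integer-pixel FFT cross-correlation with periodic boundaries standing in for `cross_correlation_shift`. It does not assume that function's actual signature or subpixel behaviour.

```python
import numpy as np


def xcorr_align(ref: np.ndarray, mov: np.ndarray) -> tuple[tuple[int, ...], np.ndarray]:
    """Align `mov` onto `ref` with integer-pixel FFT cross-correlation.

    Hypothetical stand-in for cross_correlation_shift: assumes periodic
    boundaries and returns (shift, mov rolled by that shift).
    """
    # Peak of the circular cross-correlation = shift that maps mov onto ref
    xcorr = np.fft.ifft2(np.fft.fft2(ref) * np.conj(np.fft.fft2(mov))).real
    peak = np.unravel_index(np.argmax(xcorr), xcorr.shape)
    # Wrap peak indices above n/2 to negative shifts
    shift = tuple(int(p) if p <= n // 2 else int(p) - n for p, n in zip(peak, xcorr.shape))
    return shift, np.roll(mov, shift, axis=(0, 1))


def two_image_cost(im0: np.ndarray, im1: np.ndarray) -> float:
    """Mirror of cost[a0] = np.mean(np.abs(im0 - image_shift))."""
    _, image_shift = xcorr_align(im0, im1)
    return float(np.mean(np.abs(im0 - image_shift)))


rng = np.random.default_rng(0)
im0 = rng.normal(size=(32, 32))
im1 = np.roll(im0, (-3, 5), axis=(0, 1))  # im0 with a known circular shift
print(xcorr_align(im0, im1)[0])  # (3, -5)
print(two_image_cost(im0, im1))  # 0.0
```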
Proposed solution
Extend `affine()` to use all images (not just two), something like:
```python
def _compute_cost(drift):
    # Warp all images with candidate drift
    warped_images = []
    for img_idx in range(self.num_images):
        knot = self.knots[img_idx].copy()
        # Shift knots linearly per scanline, centred on the middle scanline
        scanline_offset = np.arange(knot.shape[1]) - (knot.shape[1] - 1) / 2
        knot[0] += drift[0] * scanline_offset[:, None]
        knot[1] += drift[1] * scanline_offset[:, None]
        ...  # warp self.images[img_idx] with the updated knot -> warped
        warped_images.append(warped)
    # Compute cost: average of all images aligned to reference
    ref = warped_images[0]
    total_cost = 0
    for img_idx in range(1, self.num_images):
        _, aligned = cross_correlation_shift(
            ref, warped_images[img_idx],
            ...
        )
        ...  # accumulate into total_cost
```
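A self-contained sketch of the N-image cost from the proposal, assuming the images have already been warped with the candidate drift (the warping step is elided above and omitted here). `multi_image_cost` and `xcorr_align` are hypothetical names, and `xcorr_align` is an integer-pixel FFT cross-correlation used in place of `cross_correlation_shift`.

```python
import numpy as np


def xcorr_align(ref: np.ndarray, mov: np.ndarray) -> np.ndarray:
    """Integer-pixel FFT cross-correlation alignment of `mov` onto `ref`.

    Hypothetical stand-in for cross_correlation_shift (periodic boundaries).
    """
    xcorr = np.fft.ifft2(np.fft.fft2(ref) * np.conj(np.fft.fft2(mov))).real
    peak = np.unravel_index(np.argmax(xcorr), xcorr.shape)
    # Wrap peak indices above n/2 to negative shifts
    shift = tuple(int(p) if p <= n // 2 else int(p) - n for p, n in zip(peak, xcorr.shape))
    return np.roll(mov, shift, axis=(0, 1))


def multi_image_cost(warped_images: list[np.ndarray], ref_idx: int = 0) -> float:
    """Average mean-absolute residual of every image aligned to a reference.

    With two images this reduces to the current two-image cost.
    """
    if len(warped_images) < 2:
        raise ValueError("need at least two images")
    ref = warped_images[ref_idx]
    costs = []
    for idx, img in enumerate(warped_images):
        if idx == ref_idx:
            continue  # the reference is not compared with itself
        aligned = xcorr_align(ref, img)
        costs.append(np.mean(np.abs(ref - aligned)))
    # Average over the N - 1 pairs so the cost scale does not grow with N
    return float(np.mean(costs))


rng = np.random.default_rng(1)
base = rng.normal(size=(32, 32))
stack = [base, np.roll(base, (2, -1), axis=(0, 1)), np.roll(base, (-4, 3), axis=(0, 1))]
print(multi_image_cost(stack))  # 0.0
```

Aligning everything to `warped_images[0]` keeps the cost at N - 1 cross-correlations per evaluation; comparing all pairs would be more symmetric but needs O(N²) alignments.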