diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index dd92b914..0ebbc569 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -7,8 +7,12 @@ from quantem.core.datastructures.dataset2d import Dataset2d from quantem.core.datastructures.dataset4d import Dataset4d -from quantem.core.datastructures.polar4dstem import dataset4dstem_polar_transform - +from quantem.core.datastructures.polar4dstem import ( + auto_origin_id as _auto_origin_id, +) +from quantem.core.datastructures.polar4dstem import ( + dataset4dstem_polar_transform as _dataset4dstem_polar_transform, +) from quantem.core.utils.validators import ensure_valid_array from quantem.core.visualization import show_2d from quantem.core.visualization.visualization_utils import ScalebarConfig @@ -74,7 +78,7 @@ def __init__( _token : object | None, optional Token to prevent direct instantiation, by default None """ - mdata_keys_4dstem = ["q_to_r_rotation_ccw_deg", 'q_transpose', "ellipticity"] + mdata_keys_4dstem = ["q_to_r_rotation_ccw_deg", "q_transpose", "ellipticity"] for k in mdata_keys_4dstem: if k not in metadata.keys(): metadata[k] = None @@ -754,5 +758,12 @@ def median_filter_masked_pixels(self, mask: np.ndarray, kernel_width: int = 3): self.array[:, :, x_min:x_max, y_min:y_max], axis=(2, 3) ) + def auto_origin_id(self, **kwargs): + """Find diffraction centers by minimizing angular intensity variation. + + Delegates to the module-level ``auto_origin_id`` function. + See its docstring for full parameter details. 
+ """ + return _auto_origin_id(self, **kwargs) - polar_transform = dataset4dstem_polar_transform \ No newline at end of file + polar_transform = _dataset4dstem_polar_transform diff --git a/src/quantem/core/datastructures/polar4dstem.py b/src/quantem/core/datastructures/polar4dstem.py index 6619af5c..d237ed73 100644 --- a/src/quantem/core/datastructures/polar4dstem.py +++ b/src/quantem/core/datastructures/polar4dstem.py @@ -1,7 +1,12 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + import numpy as np +import torch +import torch.nn.functional as F from numpy.typing import NDArray -from typing import Any, TYPE_CHECKING -from scipy.ndimage import map_coordinates +from tqdm import tqdm if TYPE_CHECKING: from .dataset4dstem import Dataset4dstem @@ -21,6 +26,7 @@ def __init__( units: list[str] | tuple | list, signal_units: str = "arb. units", metadata: dict | None = None, + origin_array: NDArray | None = None, _token: object | None = None, ): if metadata is None: @@ -31,8 +37,6 @@ def __init__( "polar_radial_step", "polar_num_annular_bins", "polar_two_fold_rotation_symmetry", - "polar_origin_row", - "polar_origin_col", "polar_ellipse_params", ] for k in mdata_keys_polar: @@ -48,6 +52,7 @@ def __init__( metadata=metadata, _token=_token, ) + self.origin_array = origin_array @classmethod def from_array( @@ -62,7 +67,10 @@ def from_array( ) -> "Polar4dstem": array = np.asarray(array) if array.ndim != 4: - raise ValueError("Polar4dstem.from_array expects a 4D array.") + raise ValueError( + f"Found array with shape: {array.shape}. " + "Polar4dstem.from_array expects a 4D array." 
+ ) if origin is None: origin = np.zeros(4, dtype=float) if sampling is None: @@ -91,68 +99,418 @@ def n_r(self) -> int: return int(self.array.shape[3]) -def _precompute_polar_coords( +def _to_numpy(tensor: torch.Tensor) -> NDArray: + """Convert torch tensor to numpy array.""" + return tensor.detach().cpu().numpy() + + +def _normalize_coords_for_grid_sample( + coords_y: torch.Tensor, + coords_x: torch.Tensor, + height: int, + width: int, +) -> torch.Tensor: + """ + Convert pixel coordinates to normalized [-1, 1] coordinates for grid_sample. + grid_sample expects x_norm = 2*x/(W-1) - 1, y_norm = 2*y/(H-1) - 1, + stacked as (..., 2) in [x, y] order. + """ + x_norm = 2.0 * coords_x / (width - 1) - 1.0 + y_norm = 2.0 * coords_y / (height - 1) - 1.0 + return torch.stack([x_norm, y_norm], dim=-1) + + +def _polar_to_cartesian_offsets( + phi: torch.Tensor, + r: torch.Tensor, + ellipse_params: tuple[float, float, float] | None, + device: str = "cpu", +) -> tuple[torch.Tensor, torch.Tensor]: + """Convert polar (phi, r) grids to Cartesian (x, y) offsets from the origin, + optionally correcting for elliptical distortion.""" + if ellipse_params is None: + x = r * torch.cos(phi) + y = r * torch.sin(phi) + else: + if len(ellipse_params) != 3: + raise ValueError("ellipse_params must be (a, b, theta_deg).") + a, b, theta_deg = ellipse_params + theta = torch.deg2rad(torch.tensor(theta_deg, dtype=torch.float32, device=device)) + # Rotate into the ellipse frame, scale by a/b to undo the distortion, + # then rotate back so sampling follows the true circular rings + alpha = phi - theta + u = (a / b) * r * torch.cos(alpha) + v_prime = r * torch.sin(alpha) + cos_t = torch.cos(theta) + sin_t = torch.sin(theta) + x = u * cos_t - v_prime * sin_t + y = u * sin_t + v_prime * cos_t + return x, y + + +def _build_candidate_grids( + base_x_norm: torch.Tensor, + base_y_norm: torch.Tensor, + center_row: int, + center_col: int, + margin: int, ny: int, nx: int, - origin_row: float, - origin_col: 
float, + x_norm_scale: float, + y_norm_scale: float, + device: str = "cpu", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build a batch of normalized sampling grids for all candidate origins + within a search window around (center_row, center_col).""" + # Enumerate all pixel positions in the search window, clamped to image bounds + rows = torch.arange( + max(0, center_row - margin), + min(ny, center_row + margin + 1), + dtype=torch.long, + device=device, + ) + cols = torch.arange( + max(0, center_col - margin), + min(nx, center_col + margin + 1), + dtype=torch.long, + device=device, + ) + row_grid, col_grid = torch.meshgrid(rows, cols, indexing="ij") + row_flat, col_flat = row_grid.reshape(-1), col_grid.reshape(-1) + # Shift the pre-computed polar offsets to each candidate origin, + # converting to grid_sample's [-1, 1] normalized coordinates + grid_x = base_x_norm.unsqueeze(0) + (col_flat.float() * x_norm_scale - 1.0)[:, None, None] + grid_y = base_y_norm.unsqueeze(0) + (row_flat.float() * y_norm_scale - 1.0)[:, None, None] + grids = torch.stack([grid_x, grid_y], dim=-1) # (N, n_phi, n_r, 2) + return row_flat, col_flat, grids + + +def _angular_std_scores( + dp_batch: torch.Tensor, + grids: torch.Tensor, + min_r_idx: int, + max_r_idx: int, +) -> torch.Tensor: + """Score candidate origins by angular std over a mid-radius band. 
+ Lower scores indicate better centering.""" + n = grids.shape[0] + # Sample the diffraction pattern at each candidate's polar grid positions + polars = F.grid_sample( + dp_batch.expand(n, -1, -1, -1), + grids, + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + # A correctly centered pattern has uniform intensity along each ring, + # so the angular std is minimized at the true center + region = polars.squeeze(1)[:, :, min_r_idx:max_r_idx] + return region.std(dim=1).sum(dim=1) + + +def _build_polar_sampling_offsets( ellipse_params: tuple[float, float, float] | None, num_annular_bins: int, radial_min: float, - radial_max: float | None, + radial_max_eff: float, radial_step: float, two_fold_rotation_symmetry: bool, -) -> tuple[NDArray, NDArray, NDArray, float]: - origin_row = float(origin_row) - origin_col = float(origin_col) + device: str = "cpu", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Build origin-independent Cartesian offsets for a polar sampling grid. + Returns (offset_x, offset_y, phi_bins, radial_bins) where offset_x and + offset_y have shape (n_phi, n_r) and represent pixel displacements from + an arbitrary origin.""" if radial_step <= 0: - raise ValueError("radial_step must be > 0.") + raise ValueError(f"Got radial_step = {radial_step}. 
radial_step must be > 0.") if num_annular_bins < 1: raise ValueError("num_annular_bins must be >= 1.") + + radial_bins = torch.arange( + radial_min, radial_max_eff, radial_step, dtype=torch.float32, device=device + ) + if radial_bins.numel() == 0: + radial_bins = torch.tensor([radial_min], dtype=torch.float32, device=device) + phi_range = torch.pi if two_fold_rotation_symmetry else 2.0 * torch.pi + # Drop the last endpoint because 0 and 2pi (or pi) are the same angle + phi_bins = torch.linspace( + 0.0, phi_range, num_annular_bins + 1, dtype=torch.float32, device=device + )[:-1] + phi_grid, r_grid = torch.meshgrid(phi_bins, radial_bins, indexing="ij") + # Compute offsets relative to origin (0,0) so they can be reused + # for any candidate origin by simple translation + offset_x, offset_y = _polar_to_cartesian_offsets(phi_grid, r_grid, ellipse_params, device) + return offset_x, offset_y, phi_bins, radial_bins + + +def _compute_radial_max( + ny: int, + nx: int, + origin_row: float, + origin_col: float, + radial_max: float | None, + radial_min: float, + radial_step: float, +) -> float: + """Compute the effective maximum radius, clamped to image bounds.""" + # Use the shortest distance from the origin to any image edge so the + # polar grid never samples outside the image bounds if radial_max is None: - r_row_pos = origin_row - r_row_neg = (ny - 1) - origin_row - r_col_pos = origin_col - r_col_neg = (nx - 1) - origin_col - radial_max_eff = float(min(r_row_pos, r_row_neg, r_col_pos, r_col_neg)) + radial_max_eff = float( + min( + origin_row, + (ny - 1) - origin_row, + origin_col, + (nx - 1) - origin_col, + ) + ) else: radial_max_eff = float(radial_max) + # Guarantee at least one radial bin if radial_max_eff <= radial_min: radial_max_eff = radial_min + radial_step - radial_bins = np.arange(radial_min, radial_max_eff, radial_step, dtype=np.float64) - if radial_bins.size == 0: - radial_bins = np.array([radial_min], dtype=np.float64) - if two_fold_rotation_symmetry: - 
phi_range = np.pi - else: - phi_range = 2.0 * np.pi - phi_bins = np.linspace(0.0, phi_range, num_annular_bins, endpoint=False, dtype=np.float64) - phi_grid, r_grid = np.meshgrid(phi_bins, radial_bins, indexing="ij") - if ellipse_params is None: - x = r_grid * np.cos(phi_grid) - y = r_grid * np.sin(phi_grid) + return radial_max_eff + + +def _precompute_polar_coords( + ny: int, + nx: int, + origin_row: float, + origin_col: float, + ellipse_params: tuple[float, float, float] | None, + num_annular_bins: int, + radial_min: float, + radial_max: float | None, + radial_step: float, + two_fold_rotation_symmetry: bool, + device: str = "cpu", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]: + """Build a normalized sampling grid for a single known origin.""" + origin_row = float(origin_row) + origin_col = float(origin_col) + # Clamp radial range so the polar grid stays within image bounds + radial_max_eff = _compute_radial_max( + ny, + nx, + origin_row, + origin_col, + radial_max, + radial_min, + radial_step, + ) + # Get origin-independent polar offsets in pixel coordinates + offset_x, offset_y, phi_bins, radial_bins = _build_polar_sampling_offsets( + ellipse_params, + num_annular_bins, + radial_min, + radial_max_eff, + radial_step, + two_fold_rotation_symmetry, + device, + ) + # Translate offsets to absolute pixel pos at this origin + coords_x = offset_x + origin_col + coords_y = offset_y + origin_row + # Convert to [-1, 1] normalized coordinates expected by grid_sample + grid = _normalize_coords_for_grid_sample(coords_y, coords_x, ny, nx) + grid = grid.unsqueeze(0) # (1, n_phi, n_r, 2) + return grid, phi_bins, radial_bins, radial_max_eff + + +def auto_origin_id( + data: "Dataset4dstem", + *, + ellipse_params: tuple[float, float, float] | None = None, + num_annular_bins: int = 180, + radial_min: float = 0.0, + radial_max: float | None = None, + radial_step: float = 1.0, + two_fold_rotation_symmetry: bool = False, + device: str = "cpu", +) -> NDArray: + """ + 
Automatic diffraction center finding by minimizing angular intensity + variation in the polar transform. A correctly centered diffraction + pattern has uniform intensity along each ring, so the center that + minimizes the angular standard deviation is the true beam center. + + Uses a coarse-to-fine search on the mean diffraction pattern to find + a global center, then refines per scan position to account for beam + drift across the scan. + + Parameters + ---------- + data : Dataset4dstem + A 4D-STEM dataset (or 2D diffraction pattern wrapped as 4D). + ellipse_params : tuple or None + Ellipse parameters (a, b, theta_deg) for distortion correction. + num_annular_bins : int + Number of angular bins for the final polar transform (not used + during center-finding, which uses 36 bins for speed). + radial_min : float + Minimum radius in pixels. + radial_max : float or None + Maximum radius in pixels. + radial_step : float + Radial step size in pixels. + two_fold_rotation_symmetry : bool + If True, use only 0 to pi range for angles. + device : str + Torch device for computation ("cpu", "cuda", "cuda:0", etc.). + + Returns + ------- + origin_array : np.ndarray + Array of shape (scan_y, scan_x, 2) containing (row, col) origin + estimates in pixels. 
+ """ + if len(data.array.shape) == 2: + ny, nx = data.array.shape + scan_y, scan_x = 1, 1 + elif len(data.array.shape) == 4: + scan_y, scan_x, ny, nx = data.array.shape else: - if len(ellipse_params) != 3: - raise ValueError("ellipse_params must be (a, b, theta_deg).") - a, b, theta_deg = ellipse_params - theta = np.deg2rad(theta_deg) - alpha = phi_grid - theta - u = (a / b) * r_grid * np.cos(alpha) - v_prime = r_grid * np.sin(alpha) - cos_t = np.cos(theta) - sin_t = np.sin(theta) - x = u * cos_t - v_prime * sin_t - y = u * sin_t + v_prime * cos_t - coords_y = y + origin_row - coords_x = x + origin_col - coords = np.stack((coords_y, coords_x), axis=0) - return coords, phi_bins, radial_bins, radial_max_eff + raise ValueError( + f" Got array with shape {data.array.shape}." + "To use auto_origin_id, pass a 2D or 4DSTEM dataset." + ) + + origin_array = np.zeros((scan_y, scan_x, 2), dtype=float) + total_positions = scan_y * scan_x + + # first get COM of mean DP because it gives a robust rough center + array_4d = data.array if data.array.ndim == 4 else data.array[None, None, :, :] + mean_dp_np = array_4d.mean(axis=(0, 1)).astype(np.float32) + total_intensity = mean_dp_np.sum() + yy_grid, xx_grid = np.mgrid[0:ny, 0:nx] + com_row = int(round(float((yy_grid * mean_dp_np).sum() / total_intensity))) + com_col = int(round(float((xx_grid * mean_dp_np).sum() / total_intensity))) + + # building a fixed polar grid that is safe for all candidates + # safe_rmax ensures no candidate's grid extends outside the image + global_margin = 20 + safe_rmax = float( + min( + com_row - global_margin, + (ny - 1) - (com_row + global_margin), + com_col - global_margin, + (nx - 1) - (com_col + global_margin), + ) + ) + if radial_max is not None: + safe_rmax = min(safe_rmax, float(radial_max)) + if safe_rmax <= radial_min: + safe_rmax = radial_min + radial_step + # use very coarse binning because asymmetry is still captured at + # low angular resolution and is significantly faster + search_n_phi = 
36 + offset_x, offset_y, _, radial_bins = _build_polar_sampling_offsets( + ellipse_params, + search_n_phi, + radial_min, + safe_rmax, + radial_step, + two_fold_rotation_symmetry, + device, + ) + n_r = radial_bins.numel() + min_r_idx = int(np.floor(0.1 * n_r)) + max_r_idx = int(np.ceil(0.9 * n_r)) + # Normalize offsets to [-1, 1] because grid_sample expects normalized coordinates + x_norm_scale = 2.0 / (nx - 1) + y_norm_scale = 2.0 / (ny - 1) + base_x_norm = offset_x * x_norm_scale + base_y_norm = offset_y * y_norm_scale + + # now find actual center + # Coarse search over ±global_margin around COM + coarse_step = 4 + coarse_rows = torch.arange( + max(0, com_row - global_margin), + min(ny, com_row + global_margin + 1), + coarse_step, + dtype=torch.long, + device=device, + ) + coarse_cols = torch.arange( + max(0, com_col - global_margin), + min(nx, com_col + global_margin + 1), + coarse_step, + dtype=torch.long, + device=device, + ) + # Create all (row, col) candidate pairs and flatten for batched evaluation + coarse_row_grid, coarse_col_grid = torch.meshgrid(coarse_rows, coarse_cols, indexing="ij") + coarse_row_flat, coarse_col_flat = coarse_row_grid.reshape(-1), coarse_col_grid.reshape(-1) + # Shift polar offsets to each candidate origin in normalized coordinates + coarse_gx = ( + base_x_norm.unsqueeze(0) + (coarse_col_flat.float() * x_norm_scale - 1.0)[:, None, None] + ) + coarse_gy = ( + base_y_norm.unsqueeze(0) + (coarse_row_flat.float() * y_norm_scale - 1.0)[:, None, None] + ) + coarse_grids = torch.stack([coarse_gx, coarse_gy], dim=-1) + # Score all coarse candidates on the mean DP and pick the best one + mean_dp_batch = torch.from_numpy(mean_dp_np).to(device).unsqueeze(0).unsqueeze(0) + coarse_scores = _angular_std_scores(mean_dp_batch, coarse_grids, min_r_idx, max_r_idx) + best_coarse_idx = coarse_scores.argmin().item() + coarse_best_row = int(coarse_row_flat[best_coarse_idx].item()) + coarse_best_col = int(coarse_col_flat[best_coarse_idx].item()) + + # Fine 
search (step=1) around coarse best for global center of mean DP + fine_margin = 6 + fine_row_flat, fine_col_flat, fine_grids = _build_candidate_grids( + base_x_norm, + base_y_norm, + coarse_best_row, + coarse_best_col, + fine_margin, + ny, + nx, + x_norm_scale, + y_norm_scale, + device, + ) + fine_scores = _angular_std_scores(mean_dp_batch, fine_grids, min_r_idx, max_r_idx) + best_fine_idx = fine_scores.argmin().item() + global_row = int(fine_row_flat[best_fine_idx].item()) + global_col = int(fine_col_flat[best_fine_idx].item()) + # Get center for each scan pos by fine search around global center + # Assuming that the center doesn't shift more than 10 pixels across the scan + local_margin = 10 + local_rf, local_cf, local_grids = _build_candidate_grids( + base_x_norm, + base_y_norm, + global_row, + global_col, + local_margin, + ny, + nx, + x_norm_scale, + y_norm_scale, + device, + ) + pbar = tqdm(total=total_positions, desc="Finding origin for each scan position") + for y_pos in range(scan_y): + row_dps = torch.from_numpy(array_4d[y_pos].astype(np.float32)).to( + device + ) # (scan_x, ny, nx) + + for x_pos in range(scan_x): + dp_batch = row_dps[x_pos].unsqueeze(0).unsqueeze(0) + scores = _angular_std_scores(dp_batch, local_grids, min_r_idx, max_r_idx) + best_idx = scores.argmin().item() + origin_array[y_pos, x_pos, 0] = local_rf[best_idx].item() + origin_array[y_pos, x_pos, 1] = local_cf[best_idx].item() + pbar.update(1) + + pbar.close() + return origin_array def dataset4dstem_polar_transform( self: "Dataset4dstem", - origin_row: float | int | NDArray, - origin_col: float | int | NDArray, + origin_array: NDArray | torch.Tensor | None = None, ellipse_params: tuple[float, float, float] | None = None, num_annular_bins: int = 180, radial_min: float = 0.0, @@ -161,42 +519,120 @@ def dataset4dstem_polar_transform( two_fold_rotation_symmetry: bool = False, name: str | None = None, signal_units: str | None = None, -) -> Polar4dstem: + scan_pos: tuple[int, int] | None = 
None, + device: str = "cpu", +) -> Polar4dstem | torch.Tensor: if self.array.ndim != 4: - raise ValueError("polar_transform requires a 4D-STEM dataset (ndim=4).") + raise ValueError( + f"Found array with shape: {self.array.shape}. " + "polar_transform requires a 4D-STEM dataset (ndim=4)." + ) scan_y, scan_x, ny, nx = self.array.shape - origin_row_f = float(origin_row) - origin_col_f = float(origin_col) - coords, phi_bins, radial_bins, radial_max_eff = _precompute_polar_coords( + + # Standardize origin_array input + if isinstance(origin_array, torch.Tensor): + origin_array = _to_numpy(origin_array) + origin_array = np.asarray(origin_array) if origin_array is not None else None + if origin_array is None: + center = np.array([(ny - 1) / 2.0, (nx - 1) / 2.0], dtype=float) + origins = np.broadcast_to(center, (scan_y, scan_x, 2)).copy() + elif origin_array.shape == (2,): + origins = np.empty((scan_y, scan_x, 2), dtype=float) + origins[...] = origin_array + elif origin_array.shape == (scan_y, scan_x, 2): + origins = origin_array + else: + raise ValueError( + f" Got {origin_array.shape}. " + "origin_array must have shape None, (2,) or (scan_y, scan_x, 2)." 
+ ) + + # If scan_pos is provided, compute polar transform only for that position + if scan_pos is not None: + iy, ix = scan_pos + dp = torch.from_numpy(self.array[iy, ix].astype(np.float32)).to(device) + r0 = float(origins[iy, ix, 0]) + c0 = float(origins[iy, ix, 1]) + grid, phi_bins, radial_bins, radial_max_eff = _precompute_polar_coords( + ny=ny, + nx=nx, + origin_row=r0, + origin_col=c0, + ellipse_params=ellipse_params, + num_annular_bins=num_annular_bins, + radial_min=radial_min, + radial_max=radial_max, + radial_step=radial_step, + two_fold_rotation_symmetry=two_fold_rotation_symmetry, + device=device, + ) + dp_batch = dp.unsqueeze(0).unsqueeze(0) # (1, 1, ny, nx) + polar2d = F.grid_sample( + dp_batch, + grid, + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + return polar2d.squeeze(0).squeeze(0) # (n_phi, n_r) + + # Use the global minimum safe radius across all origins so every scan + # position maps to the same-size polar grid (required for a uniform 4D output) + if radial_max is None: + r_row_pos = origins[:, :, 0] + r_row_neg = (ny - 1) - origins[:, :, 0] + r_col_pos = origins[:, :, 1] + r_col_neg = (nx - 1) - origins[:, :, 1] + radial_max_eff_array = np.minimum.reduce([r_row_pos, r_row_neg, r_col_pos, r_col_neg]) + radial_max = float(max(radial_max_eff_array.min(), radial_min + radial_step)) + + # Compute grid for first position to get output shape + grid, phi_bins, radial_bins, radial_max_eff = _precompute_polar_coords( ny=ny, nx=nx, - origin_row=origin_row_f, - origin_col=origin_col_f, + origin_row=float(origins[0, 0, 0]), + origin_col=float(origins[0, 0, 1]), ellipse_params=ellipse_params, num_annular_bins=num_annular_bins, radial_min=radial_min, radial_max=radial_max, radial_step=radial_step, two_fold_rotation_symmetry=two_fold_rotation_symmetry, + device=device, ) - n_phi = phi_bins.size - n_r = radial_bins.size - result_dtype = np.result_type(self.array.dtype, np.float32) - out = np.empty((scan_y, scan_x, n_phi, n_r), 
dtype=result_dtype) + n_phi = phi_bins.numel() + n_r = radial_bins.numel() + out = np.empty((scan_y, scan_x, n_phi, n_r), dtype=np.float32) for iy in range(scan_y): for ix in range(scan_x): - dp = self.array[iy, ix] - out[iy, ix] = map_coordinates( - dp, - coords, - order=1, - mode="constant", - cval=0.0, + dp = torch.from_numpy(self.array[iy, ix].astype(np.float32)).to(device) + r0 = float(origins[iy, ix, 0]) + c0 = float(origins[iy, ix, 1]) + grid, _, _, _ = _precompute_polar_coords( + ny=ny, + nx=nx, + origin_row=r0, + origin_col=c0, + ellipse_params=ellipse_params, + num_annular_bins=num_annular_bins, + radial_min=radial_min, + radial_max=radial_max, + radial_step=radial_step, + two_fold_rotation_symmetry=two_fold_rotation_symmetry, + device=device, ) - if two_fold_rotation_symmetry: - phi_range = np.pi - else: - phi_range = 2.0 * np.pi + dp_batch = dp.unsqueeze(0).unsqueeze(0) + polar2d = F.grid_sample( + dp_batch, + grid, + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + out[iy, ix] = _to_numpy(polar2d.squeeze(0).squeeze(0)) + + # Express polar axes in physical units matching the input dataset's calibration + phi_range = np.pi if two_fold_rotation_symmetry else 2.0 * np.pi phi_step_deg = (phi_range / float(n_phi)) * (180.0 / np.pi) sampling = np.zeros(4, dtype=float) origin = np.zeros(4, dtype=float) @@ -220,8 +656,6 @@ def dataset4dstem_polar_transform( "polar_radial_step": float(radial_step), "polar_num_annular_bins": int(n_phi), "polar_two_fold_rotation_symmetry": bool(two_fold_rotation_symmetry), - "polar_origin_row": origin_row_f, - "polar_origin_col": origin_col_f, "polar_ellipse_params": tuple(ellipse_params) if ellipse_params is not None else None, } ) @@ -233,5 +667,6 @@ def dataset4dstem_polar_transform( units=units, signal_units=signal_units if signal_units is not None else self.signal_units, metadata=metadata, + origin_array=origins, _token=Polar4dstem._token, ) diff --git a/src/quantem/diffraction/__init__.py 
b/src/quantem/diffraction/__init__.py index 2a79312b..eca03dc5 100644 --- a/src/quantem/diffraction/__init__.py +++ b/src/quantem/diffraction/__init__.py @@ -1,3 +1,3 @@ -from quantem.diffraction.polar import RDF as RDF +from quantem.diffraction.polar import PairDistributionFunction as PairDistributionFunction from quantem.diffraction.strain_autocorrelation import StrainMapAutocorrelation as StrainMapAutocorrelation from quantem.diffraction.maped import MAPED as MAPED diff --git a/src/quantem/diffraction/polar.py b/src/quantem/diffraction/polar.py index 7e87eff3..0507ef83 100644 --- a/src/quantem/diffraction/polar.py +++ b/src/quantem/diffraction/polar.py @@ -1,28 +1,44 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import Any, List, Union +from collections.abc import Iterable +from typing import Any, Literal import matplotlib.pyplot as plt import numpy as np +import torch +import torch.nn.functional as F from numpy.typing import NDArray from quantem.core.datastructures.dataset2d import Dataset2d from quantem.core.datastructures.dataset3d import Dataset3d from quantem.core.datastructures.dataset4dstem import Dataset4dstem -from quantem.core.datastructures.polar4dstem import Polar4dstem +from quantem.core.datastructures.polar4dstem import ( + Polar4dstem, + auto_origin_id, + dataset4dstem_polar_transform, +) from quantem.core.io.serialize import AutoSerialize from quantem.core.utils.validators import ensure_valid_array +# TODO: subpixel origin finding (auto_origin_id currently uses integer pixel search) +# TODO: elliptical distortion correction in origin finding +# TODO: beamstop mask support (mask diffraction-space pixels before azimuthal averaging) -class RDF(AutoSerialize): + +class PairDistributionFunction(AutoSerialize): """ - Radial distribution / fluctuation electron microscopy analysis helper. + Pair distribution function (PDF) utilities for diffraction / 4D-STEM data. 
This class wraps a 4D-STEM (or 2D diffraction) dataset and stores a polar-transformed representation as a Polar4dstem instance in `self.polar`. - Analysis methods (radial statistics, PDF, FEM, clustering, etc.) are - provided as stubs for now and will be implemented in future revisions. + The PDF pipeline provides methods to compute: + + - azimuthal integration to obtain I(k) + - background fitting using a parametric model in k^2 / k^4 + - formation of F(k) and a windowed sine transform to obtain G(r) + - optional density estimation and origin correction (Yoshimoto & Omote-style iteration) + - basic plotting helpers for I(k), background, F(k), G(r), and g(r) + Some analysis methods (FEM, clustering, etc.) will be implemented in future revisions. """ _token = object() @@ -31,31 +47,32 @@ def __init__( self, polar: Polar4dstem, input_data: Any | None = None, + device: str = "cpu", _token: object | None = None, ): if _token is not self._token: raise RuntimeError( - "Use RadialDistributionFunction.from_data() to instantiate this class." + "Direct instantiation of PairDistributionFunction is not allowed. " + "Use PairDistributionFunction.from_data() to instantiate this class." 
) super().__init__() self.polar = polar self.input_data = input_data + self.device = device - # Placeholders for analysis results (to be populated by future methods) - self.radial_mean: NDArray | None = None - self.radial_var: NDArray | None = None - self.radial_var_norm: NDArray | None = None - - self.pdf_r: NDArray | None = None - self.pdf_reduced: NDArray | None = None - self.pdf: NDArray | None = None - - self.Sk: NDArray | None = None - self.fk: NDArray | None = None - self.bg: NDArray | None = None - self.offset: float | None = None - self.Sk_mask: NDArray | None = None + self._r: torch.Tensor | None = None + self._reduced_pdf: torch.Tensor | None = None + self._pdf: torch.Tensor | None = None + self.Ik: torch.Tensor | None = None + self.Sk: torch.Tensor | None = None + self.Fk: torch.Tensor | None = None + self.bg: torch.Tensor | None = None + self.f: torch.Tensor | None = None + self.Fk_mask: torch.Tensor | None = None + self.Fk_damped: torch.Tensor | None = None + self.reduced_pdf_damped: torch.Tensor | None = None + self.rho0: float | None = None # ------------------------------------------------------------------ # Constructors @@ -63,8 +80,9 @@ def __init__( @classmethod def from_data( cls, - data: Union[NDArray, Dataset2d, Dataset3d, Dataset4dstem, Polar4dstem], + data: NDArray | Dataset2d | Dataset3d | Dataset4dstem | Polar4dstem, *, + find_origin: bool = True, origin_row: float | None = None, origin_col: float | None = None, ellipse_params: tuple[float, float, float] | None = None, @@ -73,9 +91,11 @@ def from_data( radial_max: float | None = None, radial_step: float = 1.0, two_fold_rotation_symmetry: bool = False, - ) -> "RadialDistributionFunction": + device: str = "cpu", + ): """ - Create a RadialDistributionFunction object from various input types. + -> "PairDistributionFunction" + Create a PairDistributionFunction object from various input types. 
Parameters ---------- @@ -86,47 +106,39 @@ def from_data( - Dataset2d - Dataset4dstem - Polar4dstem + + If a :class:`Polar4dstem` is provided, it is used directly and no origin finding + or polar transform is performed. + find_origin + If True, finds the origin for each scan position by calling + :meth:`find_origin`. If False, `origin_row`/`origin_col` are used (or default + to the image center). origin_row, origin_col - Diffraction-space origin (in pixels). If None, defaults to the - central pixel of the diffraction pattern. + Diffraction-space origin (in pixels), used only if `find_origin=False`. If None, + defaults to the central pixel of the diffraction pattern. Other parameters Passed through to Dataset4dstem.polar_transform when needed. """ # Polar input: use directly if isinstance(data, Polar4dstem): polar = data - return cls(polar=polar, input_data=data, _token=cls._token) - - # Dataset4dstem input: polar-transform it - if isinstance(data, Dataset4dstem): - scan_y, scan_x, ny, nx = data.array.shape - if origin_row is None: - origin_row = (ny - 1) / 2.0 - if origin_col is None: - origin_col = (nx - 1) / 2.0 - - polar = data.polar_transform( - origin_row=origin_row, - origin_col=origin_col, - ellipse_params=ellipse_params, - num_annular_bins=num_annular_bins, - radial_min=radial_min, - radial_max=radial_max, - radial_step=radial_step, - two_fold_rotation_symmetry=two_fold_rotation_symmetry, - ) - return cls(polar=polar, input_data=data, _token=cls._token) + return cls(polar=polar, input_data=data, device=device, _token=cls._token) - # Dataset2d input: wrap as a trivial 4D-STEM (1x1 scan) then polar-transform + # Dataset2d input: wrap as a trivial 4D-STEM (1x1 scan) and fall through if isinstance(data, Dataset2d): arr2d = data.array if arr2d.ndim != 2: - raise ValueError("Dataset2d for RDF must be 2D.") + raise ValueError( + f"Found array with shape: {arr2d.shape}. " + "Dataset2d for PairDistributionFunction must be 2D." + ) arr4 = arr2d[None, None, ...] 
# (1, 1, ky, kx) - ds4 = Dataset4dstem.from_array( + data = Dataset4dstem.from_array( array=arr4, - name=f"{data.name}_as4dstem" if getattr(data, "name", None) else "rdf_4dstem_from_2d", + name=f"{data.name}_as4dstem" + if getattr(data, "name", None) + else "rdf_4dstem_from_2d", origin=np.concatenate( [np.zeros(2, dtype=float), np.asarray(data.origin, dtype=float)] ), @@ -136,36 +148,58 @@ def from_data( units=["pixels", "pixels"] + list(data.units), signal_units=data.signal_units, ) - ny, nx = ds4.array.shape[-2:] - if origin_row is None: - origin_row = (ny - 1) / 2.0 - if origin_col is None: - origin_col = (nx - 1) / 2.0 - polar = ds4.polar_transform( - origin_row=origin_row, - origin_col=origin_col, + # Dataset4dstem input: polar-transform it + if isinstance(data, Dataset4dstem): + scan_y, scan_x, ny, nx = data.array.shape + if find_origin: + origin_array = auto_origin_id( + data, + ellipse_params=ellipse_params, + num_annular_bins=num_annular_bins, + radial_min=radial_min, + radial_max=radial_max, + radial_step=radial_step, + two_fold_rotation_symmetry=two_fold_rotation_symmetry, + device=device, + ) + else: + if origin_row is None: + origin_row = (ny - 1) / 2.0 + if origin_col is None: + origin_col = (nx - 1) / 2.0 + origin_array = np.zeros((scan_y, scan_x, 2), dtype=float) + origin_array[..., 0] = origin_row + origin_array[..., 1] = origin_col + + polar = dataset4dstem_polar_transform( + data, + origin_array=origin_array, ellipse_params=ellipse_params, num_annular_bins=num_annular_bins, radial_min=radial_min, radial_max=radial_max, radial_step=radial_step, two_fold_rotation_symmetry=two_fold_rotation_symmetry, + device=device, ) - return cls(polar=polar, input_data=data, _token=cls._token) + return cls(polar=polar, input_data=data, device=device, _token=cls._token) # Dataset3d input: not yet specified how to interpret if isinstance(data, Dataset3d): raise NotImplementedError( - "RadialDistributionFunction.from_data does not yet support Dataset3d inputs." 
+ "PairDistributionFunction.from_data does not yet support Dataset3d inputs. " + "Please provide a 4D-STEM dataset or a 2D diffraction pattern." ) # Numpy array input arr = ensure_valid_array(data) if arr.ndim == 2: ds2 = Dataset2d.from_array(arr, name="rdf_input_2d") + return cls.from_data( ds2, + find_origin=find_origin, origin_row=origin_row, origin_col=origin_col, ellipse_params=ellipse_params, @@ -174,11 +208,13 @@ def from_data( radial_max=radial_max, radial_step=radial_step, two_fold_rotation_symmetry=two_fold_rotation_symmetry, + device=device, ) elif arr.ndim == 4: ds4 = Dataset4dstem.from_array(arr, name="rdf_input_4dstem") return cls.from_data( ds4, + find_origin=find_origin, origin_row=origin_row, origin_col=origin_col, ellipse_params=ellipse_params, @@ -187,10 +223,12 @@ def from_data( radial_max=radial_max, radial_step=radial_step, two_fold_rotation_symmetry=two_fold_rotation_symmetry, + device=device, ) else: raise ValueError( - "RadialDistributionFunction.from_data only supports 2D or 4D arrays." + f"Found array with shape: {arr.shape}. " + "PairDistributionFunction.from_data only supports 2D or 4D arrays." ) # ------------------------------------------------------------------ @@ -204,123 +242,1146 @@ def qq(self) -> Any: """ # Polar4dstem dims: (scan_y, scan_x, phi, r) # radial axis is 3 - return self.polar.coords_units(3) + # origin[3] is the physical q-value at bin 0 (radial_min * pixel_size), + # sampling[3] is the physical step per bin (radial_step * pixel_size). 
+ n = self.polar.shape[3] + origin_r = float(np.asarray(self.polar.origin)[3]) + sampling_r = float(np.asarray(self.polar.sampling)[3]) + return np.arange(n, dtype=float) * sampling_r + origin_r @property - def radial_bins(self) -> Any: + def r(self) -> NDArray | None: + """Real-space radial grid as a numpy array.""" + if self._r is None: + return None + return self._to_numpy(self._r) + + @property + def reduced_pdf(self) -> NDArray | None: + """Reduced pair distribution function G(r) as a numpy array.""" + if self._reduced_pdf is None: + return None + return self._to_numpy(self._reduced_pdf) + + @property + def pdf(self) -> NDArray | None: + """Pair distribution function g(r) as a numpy array.""" + if self._pdf is None: + return None + return self._to_numpy(self._pdf) + + # ------------------------------------------------------------------ + # Helper functions + # ------------------------------------------------------------------ + def _get_mask_bool(self, mask_realspace): """ - Radial bin centers in pixel units (convenience alias). + Normalize a real-space mask specification to a boolean (rx, ry) mask. + + Returns + ------- + mask_bool : np.ndarray or None + Boolean mask of shape (rx, ry), or None if `mask_realspace` is None. """ - return self.polar.coords(3) + mask_bool = None + if mask_realspace is not None: + rx, ry = self.polar.array.shape[:2] + mask_realspace = np.asarray(mask_realspace) + + if mask_realspace.dtype == bool and mask_realspace.shape == (rx, ry): + mask_bool = mask_realspace + else: + raise ValueError( + f"Got shape {mask_realspace.shape}. " + f"mask_realspace must be boolean array of shape ({rx}, {ry})." 
+ ) + return mask_bool # ------------------------------------------------------------------ - # Analysis method stubs (py4DSTEM-style API) + # Torch conversion utilities # ------------------------------------------------------------------ - def calculate_radial_statistics( + + def _to_torch(self, arr: NDArray) -> torch.Tensor: + return torch.from_numpy(arr.astype(np.float32)).to(device=self.device) + + def _to_numpy(self, tensor: torch.Tensor) -> NDArray: + return tensor.detach().cpu().numpy() + + @staticmethod + def _gaussian_kernel_1d( + sigma: float, device: str = "cpu", num_sigmas: float = 3.0 + ) -> torch.Tensor: + """Create 1D Gaussian kernel for torch convolution.""" + radius = int(np.ceil(num_sigmas * sigma)) + support = torch.arange(-radius, radius + 1, dtype=torch.float32, device=device) + kernel = torch.exp(-0.5 * (support / sigma) ** 2) + kernel = kernel / kernel.sum() + return kernel + + def _gaussian_filter1d_torch( self, - mask_realspace: NDArray | None = None, - plot_results_mean: bool = False, - plot_results_var: bool = False, - figsize: tuple[float, float] = (8, 4), - returnval: bool = False, - returnfig: bool = False, - progress_bar: bool = True, - ): + Fk: torch.Tensor, + sigma: float, + mode: str = "nearest", + ) -> torch.Tensor: """ - Stub for radial statistics (FEM-style) calculation on the polar data. + Apply 1D Gaussian filter, replaces scipy.ndimage.gaussian_filter1d. + """ + kernel = self._gaussian_kernel_1d(sigma, device=self.device) + padding = len(kernel) // 2 + x = Fk.unsqueeze(0).unsqueeze(0) # reshape to (batch, channels, length) + kernel_w = kernel.view(1, 1, -1) + if mode == "nearest": + x = F.pad(x, (padding, padding), mode="replicate") + result = F.conv1d(x, kernel_w) + else: + result = F.conv1d(x, kernel_w, padding=padding) + return result.squeeze(0).squeeze(0) # reshape to (length) - Intended to compute radial mean, variance, and normalized variance - from self.polar. Not implemented yet. 
+ @staticmethod + def _scattering_model_torch( + k2: torch.Tensor, + c: torch.Tensor, + i0: torch.Tensor, + s0: torch.Tensor, + i1: torch.Tensor, + s1: torch.Tensor, + ) -> torch.Tensor: + """Torch version of the scattering model.""" + # Add small epsilon to denominators to prevent division by zero during backprop + # while still allowing s0/s1 to vary freely + eps = 1e-10 + exp1 = torch.clamp(k2 / (-2.0 * (s0**2 + eps)), min=-100, max=0) + exp2 = torch.clamp((k2**2) / (-2.0 * (s1**4 + eps)), min=-100, max=0) + # scattering model is monotonic, as is physically expected for backgrounds scattering + return c + i0 * torch.exp(exp1) + i1 * torch.exp(exp2) + + def _compute_fit_weights(self, k: torch.Tensor, kmin: float, kmax: float) -> torch.Tensor: + """ + Compute weighting tensor for background fitting. + Weights downweight low-k region (using sin² taper) and emphasize high-k values. """ - raise NotImplementedError("calculate_radial_statistics is not implemented yet.") + dk = k[1] - k[0] + k_width = kmax - kmin - def plot_radial_mean( + # sin² taper for low-k suppression + mask_low = torch.sin(torch.clamp((k - kmin) / k_width, 0.0, 1.0) * (torch.pi / 2.0)) ** 2 + # high weight where mask_low is small + # later used to divide, so large weights mean small contribution + weights = torch.where( + mask_low > 1e-4, + 1.0 / mask_low, + torch.tensor(1e6, device=self.device, dtype=k.dtype), + ) + # emphasize high-k values + weights = weights * (k[-1] - 0.9 * k + dk) + return weights + + def _closure(self, optimizer, theta, k2, Ik_norm, weights): + """match scipy curve_fit behavior""" + optimizer.zero_grad() + # Map from unconstrained to constrained (positive) space via softplus + c = F.softplus(theta[0]) + i0 = F.softplus(theta[1]) + s0 = F.softplus(theta[2]) + i1 = F.softplus(theta[3]) + s1 = F.softplus(theta[4]) + + pred = self._scattering_model_torch(k2, c, i0, s0, i1, s1) + residuals = (pred - Ik_norm) ** 2 + loss = (residuals / (weights**2)).sum() + loss.backward() + 
return loss
+
+    def _frequency_filtering(
         self,
-        log_x: bool = False,
-        log_y: bool = False,
-        figsize: tuple[float, float] = (8, 4),
-        returnfig: bool = False,
-    ):
+        Fk: torch.Tensor,
+        k_lowpass: float | None,
+        k_highpass: float | None,
+        dk: torch.Tensor,
+    ) -> torch.Tensor:
+        """Band pass filtering using torch"""
+        if (
+            k_lowpass is not None
+            and k_lowpass > 0.0
+            and k_highpass is not None
+            and k_highpass > 0.0
+        ):
+            if k_highpass > k_lowpass:
+                raise ValueError(
+                    "k_highpass is greater than k_lowpass. "
+                    "Gaussian band-pass filtering requires k_highpass < k_lowpass."
+                )
+            Fk_low = self._gaussian_filter1d_torch(Fk, sigma=k_lowpass / dk.item(), mode="nearest")
+            Fk_high = self._gaussian_filter1d_torch(
+                Fk, sigma=k_highpass / dk.item(), mode="nearest"
+            )
+            Fk = Fk_high - Fk_low
+        elif k_lowpass is not None and k_lowpass > 0.0:
+            Fk = self._gaussian_filter1d_torch(Fk, sigma=k_lowpass / dk.item(), mode="nearest")
+        elif k_highpass is not None and k_highpass > 0.0:
+            Fk_high = self._gaussian_filter1d_torch(
+                Fk, sigma=k_highpass / dk.item(), mode="nearest"
+            )
+            Fk = Fk - Fk_high
+        return Fk
+
+    def _lorch_window(self, k: torch.Tensor, kmin: float, kmax: float) -> torch.Tensor:
         """
-        Stub for plotting radial mean intensity vs scattering vector.
+        Construct a combined low-q taper and high-q Lorch window.
+ + The returned window is: + - zero outside [kmin, kmax] + - smoothly rises from 0->1 near kmin using a sin^2 ramp over 10% of the band + - applies a Lorch-style sinc factor over the full in-band region: + sin(pi * k/kmax) / (pi * k/kmax) """ - raise NotImplementedError("plot_radial_mean is not implemented yet.") + # low q taper + edge_frac_low = 0.1 # 10% of range at low-q + edge_width_low = edge_frac_low * (kmax - kmin) + low = (k >= kmin) & (k < kmin + edge_width_low) + t = (k - kmin) / edge_width_low + wk = torch.ones_like(k) + wk = torch.where(low, torch.sin(0.5 * torch.pi * t) ** 2, wk) + wk = torch.where(k < kmin, torch.zeros_like(wk), wk) + wk = torch.where(k > kmax, torch.zeros_like(wk), wk) + + # High q taper with Lorch window: w(k) = sin(pi*k/kmax)/(pi*k/kmax) + x = k / kmax + inband = (k >= kmin) & (k <= kmax) + # sinc function: sin(pi*x)/(pi*x) with limit 1 at x=0 + sinc_val = torch.where( + x == 0, + torch.ones_like(x), + torch.sin(torch.pi * x) / (torch.pi * x), + ) + lorch = torch.where(inband, sinc_val, torch.zeros_like(k)) + wk = wk * lorch + return wk - def plot_radial_var_norm( + def _compute_alpha_beta( self, - figsize: tuple[float, float] = (8, 4), - returnfig: bool = False, - ): + Q2d: torch.Tensor, + r2d: torch.Tensor, + G_beta: torch.Tensor, + r_1d: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute Yoshimoto-Omote alpha(Q) and beta(Q) integrals used for density estimation. 
+        """
+        Qsafe = torch.where(
+            Q2d == 0.0,
+            torch.tensor(1e-12, device=self.device, dtype=torch.float32),
+            Q2d,
+        )
+        alpha_int = -4 * torch.pi * r2d * torch.sin(Qsafe * r2d) / Qsafe
+        beta_int = G_beta.unsqueeze(0) * torch.sin(Qsafe * r2d) / Qsafe
+        alpha = torch.trapezoid(alpha_int, x=r_1d, dim=1)
+        beta = torch.trapezoid(beta_int, x=r_1d, dim=1)
+        return alpha, beta
+
+    # ------------------------------------------------------------------
+    # Analysis methods
+    # ------------------------------------------------------------------
+
+    # TODO: add beamstop mask support (mask diffraction-space pixels before
+    # azimuthal averaging, e.g. to exclude a beam stop shadow)
+
+    def calculate_radial_mean(
+        self,
+        mask_realspace: NDArray | None = None,
+        returnval: bool = False,
+    ) -> torch.Tensor | None:
+        """
+        Calculate the radial mean intensity from the Polar4dstem dataset.
+
+        The polar array is assumed to have shape (scan_y, scan_x, phi, k).
+        This method computes, for each scan position, the mean over the azimuthal
+        axis (phi), then averages across scan positions to produce a single 1D
+        radial curve. This result is stored in ``self.Ik``.
+
+        If a real-space mask is provided, only the selected scan positions are
+        used in the scan-position average. The computation streams chunks through torch to keep peak
+        memory low.
+
+        Parameters
+        ----------
+        mask_realspace : NDArray or None, optional
+            Boolean mask in real space used to select probe positions.
+            If ``None``, all probe positions are used.
+            Must have shape (scan_y, scan_x) where True means "include".
+        returnval : bool, optional
+            If True, return the computed 1D radial mean tensor.
+
+        Returns
+        -------
+        radial_mean : torch.Tensor or None
+            If `returnval=True`, returns the 1D radial mean intensity (Nk,).
+            Otherwise returns None.
+ """ + polar_np = self.polar.array # shape: (scan_y, scan_x, phi, k) + scan_y, scan_x, n_phi, n_k = polar_np.shape + intensity_sum = torch.zeros(n_k, device=self.device, dtype=torch.float64) + n_valid = 0 + chunk_y = 16 # number of scan_y to process at a time + for y0 in range(0, scan_y, chunk_y): + y1 = min(y0 + chunk_y, scan_y) + raw = polar_np[y0:y1] + chunk = torch.from_numpy(np.ascontiguousarray(raw)).to(self.device) + # mean over phi first -> (chunk, scan_x, k) + radial_mean = chunk.mean(dim=2) + if mask_realspace is not None: + mask_chunk = torch.from_numpy(mask_realspace[y0:y1]).to(self.device) + n_chunk = int(mask_chunk.sum()) + if n_chunk == 0: + continue + # sum unmasked intensities in chunk and count for normalization later + intensity_sum += radial_mean[mask_chunk].sum(dim=0) + n_valid += n_chunk + else: + # sum all intensities in chunk and count for normalization later + intensity_sum += radial_mean.sum(dim=(0, 1)) + n_valid += (y1 - y0) * scan_x + if n_valid == 0: + raise ValueError( + "No valid scan positions selected. The real-space mask is " + "all False or the dataset is empty." + ) + self.Ik = (intensity_sum / n_valid).float() + + if returnval: + return self.Ik + else: + return None + + def fit_bg( + self, + Ik: torch.Tensor, + kmin: float, + kmax: float, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - Stub for plotting normalized radial variance vs scattering vector. + Fit a smooth background B(k) to a radial intensity curve I(k) using + PyTorch LBFGS optimizer, with weighting that downweights the low-k + region and emphasizes higher k. + + The fitted function uses the following form (adopted from py4dstem): + B(k) = c + + i0 * exp(-k^2 / (2 s0^2)) + + i1 * exp(-k^4 / (2 s1^4)) + + Parameters + ---------- + Ik + 1D radial intensity tensor (Nk,). Produced by + :meth:`calculate_radial_mean`. + kmin, kmax + k-range (in the same units as the internally constructed `k` grid) + used to build the low-k weighting mask. 
+
+        Returns
+        -------
+        bg : torch.Tensor
+            Fitted background curve B(k), shape (Nk,).
+        f : torch.Tensor
+            Background minus the constant offset, f(k) = B(k) - c, or functionally
+            similar to ⟨f⟩²(k)
         """
-        raise NotImplementedError("plot_radial_var_norm is not implemented yet.")
+        k = self._to_torch(np.asarray(self.qq))
+        k2 = k**2
+
+        # normalize intensity
+        int_mean = Ik.mean()
+        Ik_norm = Ik / int_mean
+        # initial guesses
+        const_bg = float(Ik_norm.min())
+        int0 = float(Ik_norm.median()) - const_bg
+        sigma0 = float(k.mean())
+        # ensure positive values
+        const_bg = max(const_bg, 1e-6)
+        int0 = max(int0, 1e-6)
+        sigma0 = max(sigma0, 1e-6)
+
+        init_vals = torch.tensor(
+            [const_bg, int0, sigma0, int0, sigma0],
+            device=self.device,
+            dtype=torch.float32,
+        )
+        # Map to unconstrained space via inverse softplus: x = y + log(1 - exp(-y))
+        # For numerical stability, clamp init_vals away from zero
+        # final values must be positive for a physical model of background scattering
+        init_vals = torch.clamp(init_vals, min=1e-6)
+        theta = init_vals + torch.log(-torch.expm1(-init_vals))
+        theta = theta.clone().detach().requires_grad_(True)
+        optimizer = torch.optim.LBFGS(
+            [theta],
+            lr=1.0,
+            max_iter=20,
+            tolerance_grad=1e-7,
+            tolerance_change=1e-9,
+            line_search_fn="strong_wolfe",
+        )
+
+        # fitting weights (high-k range is emphasized for better bg estimation)
+        # this monotonic model means we don't need parameterized scattering factors
+        weights = self._compute_fit_weights(k, kmin, kmax)
 
-    def calculate_pair_dist_function(
+        prev_loss = torch.tensor(float("inf"))
+        max_outer_iter = 100
+        tol = 1e-8
+        for step in range(max_outer_iter):
+            loss = optimizer.step(lambda: self._closure(optimizer, theta, k2, Ik_norm, weights))
+            if torch.abs(prev_loss - loss) < tol:
+                break
+            prev_loss = loss
+
+        # final params (ensure positivity via softplus)
+        with torch.no_grad():
+            c = F.softplus(theta[0])
+            i0 = F.softplus(theta[1])
+            s0 = F.softplus(theta[2])
+            i1 = F.softplus(theta[3])
+
s1 = F.softplus(theta[4]) + # undo normalization + c_scaled = c * int_mean + i0_scaled = i0 * int_mean + i1_scaled = i1 * int_mean + # compute bg and the average scattering factor f(k) + bg = self._scattering_model_torch(k2, c_scaled, i0_scaled, s0, i1_scaled, s1) + f = bg - c_scaled + self.bg = bg + self.f = f + return bg, f + + def calculate_Gr( self, - k_min: float = 0.05, - k_max: float | None = None, - k_width: float = 0.25, + k_min_fit: float | None = None, + k_max_fit: float | None = None, + k_min_window: float | None = None, + k_max_window: float | None = None, k_lowpass: float | None = None, k_highpass: float | None = None, r_min: float = 0.0, r_max: float = 20.0, r_step: float = 0.02, - damp_origin_fluctuations: bool = True, - enforce_positivity: bool = True, + mask_realspace: NDArray | None = None, + damp_origin_oscillations: bool = False, density: float | None = None, - plot_background_fits: bool = False, - plot_sf_estimate: bool = False, - plot_reduced_pdf: bool = True, - plot_pdf: bool = False, - figsize: tuple[float, float] = (8, 4), - maxfev: int | None = None, + r_cut: float = 0.8, returnval: bool = False, - returnfig: bool = False, + ) -> list[NDArray] | None: + """ + Calculate the reduced pair distribution function G(r) from a 4D-STEM dataset. + + This routine: + * Computes the radial mean intensity I(k) from self.polar (optionally + restricted to a real-space mask). + * Fits a smooth background B(k) and associated f(k) using :meth:`fit_bg`. + * Constructs the reduced structure factor F(k) with optional low/highpass filtering. + * Applies a window in k (low-k sin^2 ramp x Lorch high-k taper). + * Computes the reduced PDF using a discrete sine transform: + G(r) = sum_k sin(2*pi*k*r) * F_windowed(k) + + If ``damp_origin_oscillations=True``, :meth:`estimate_density` is called + and the corrected F(k)/G(r) are stored as ``self.Fk_damped`` and + ``self.reduced_pdf_damped``. 
The estimated density is cached in + ``self.rho0`` so that a subsequent :meth:`calculate_gr` call can reuse it. + + Stored attributes: + * self.Ik, self.bg, self.Fk, self.Fk_masked + * self.Sk, self.r, self.reduced_pdf + * self.rho0, self.Fk_damped, self.reduced_pdf_damped (if damping) + + Parameters + ---------- + k_min_fit : float, optional + Minimum k (A^-1) for the background fit. + k_max_fit : float or None, optional + Maximum k (A^-1) for the background fit. + k_min_window : float or None, optional + Minimum k (A^-1) for the structure-factor Lorch window. + If None, falls back to ``k_min_fit``. + k_max_window : float or None, optional + Maximum k (A^-1) for the structure-factor Lorch window. + If None, falls back to ``k_max_fit``. + k_lowpass : float or None, optional + Low-pass Gaussian filter sigma in k-space. + k_highpass : float or None, optional + High-pass Gaussian filter sigma in k-space. + r_min : float, optional + Minimum r (A) for the real-space grid. + r_max : float, optional + Maximum r (A) for the real-space grid. + r_step : float, optional + Step size in r (A) for the real-space grid. + mask_realspace : NDArray or None, optional + Boolean real-space mask selecting probe positions. + damp_origin_oscillations : bool, optional + If True, run :meth:`estimate_density` and store corrected F(k)/G(r). + density : float or None, optional + Known number density (atoms/A^3). If provided together with + ``damp_origin_oscillations=True``, the S(k)/G(r) correction uses + this value instead of estimating it. + r_cut : float, optional + Minimum radial distance (A) for peak search in density estimation. + Forwarded to :meth:`estimate_density`. + returnval : bool, optional + If True, return ``[r, G(r)]`` as numpy arrays. 
+ + Returns + ------- + list[np.ndarray] or None + """ + # clear results from any previous run so stale state doesn't leak + self.Fk_damped = None + self.reduced_pdf_damped = None + self.rho0 = None + # this is missing a 2pi term that we add back during the pdf calc later + k_np = np.asarray(self.qq) + k = self._to_torch(k_np) + dk = k[1] - k[0] + # small epsilon to avoid division by very small k values + k_safe = torch.clamp(k, min=1e-10) + self.kmax_fit = k_max_fit if k_max_fit is not None else float(k.max()) + self.kmin_fit = k_min_fit if k_min_fit is not None else float(k.min()) + # window range defaults to bg-fit range when not specified + self.kmin_window = k_min_window if k_min_window is not None else self.kmin_fit + self.kmax_window = k_max_window if k_max_window is not None else self.kmax_fit + + mask_bool = self._get_mask_bool(mask_realspace) + # reuse existing radial mean if already computed + if self.Ik is not None: + Ik = self.Ik + else: + Ik = self.calculate_radial_mean(mask_realspace=mask_bool, returnval=True) + # reuse existing background fit if already computed + if self.bg is not None and self.f is not None: + bg, f = self.bg, self.f + else: + bg, f = self.fit_bg(Ik, self.kmin_fit, self.kmax_fit) + # prevent division by near-zero values which cause NaNs at high k + f_safe = torch.clamp(f, min=1e-10 * f.max()) + + # below is the standard definition of F(k) used in PDF analysis, except for missing 2pi factor + Fk = (Ik - bg) * k_safe / f_safe + # apply optional frequency filtering for noise reduction + Fk = self._frequency_filtering(Fk, k_lowpass, k_highpass, dk) + # Compute Sk from Fk BEFORE applying the 2pi scaling, + # so that estimate_density corrections are on the same scale + self.Sk = torch.ones_like(k) + mask = k > 0 + self.Sk = torch.where(mask, 1.0 + (Fk / k_safe), self.Sk) + # apply that missing 2pi factor + Fk = Fk * 2 * torch.pi + # damp edges with lorch window + wk = self._lorch_window(k, self.kmin_window, self.kmax_window) + Fk_win = 
Fk * wk + + r = torch.arange(r_min, r_max, r_step, device=self.device, dtype=torch.float32) + ka, ra = torch.meshgrid(k, r, indexing="ij") + # compute reduced PDF using discrete sine transform + reduced_pdf = ( + (2 / torch.pi) + * dk + * 2 + * torch.pi + * torch.sum( + torch.sin(2 * torch.pi * ra * ka) * Fk_win[:, None], + dim=0, + ) + ) + reduced_pdf[0] = 0 # physically must be at 0 when r = 0 + + self.Ik = Ik + self.bg = bg + self.Fk = Fk + self.Fk_masked = Fk_win + self._r = r + self._reduced_pdf = reduced_pdf + + # optionally damped unphysical oscillations near the origin by iteratively estimating density and correcting F(k) + if damp_origin_oscillations: + density_est = self.estimate_density( + density=density, + r_cut=r_cut, + max_iter=20, + tol_percent=1e-1, + ) + self.rho0 = density_est[0] + self.Fk_damped = density_est[1] + self.reduced_pdf_damped = density_est[2] + + if returnval: + Gr = ( + self.reduced_pdf_damped + if self.reduced_pdf_damped is not None + else self._reduced_pdf + ) + return [self._to_numpy(self._r), self._to_numpy(Gr)] + return None + + def calculate_gr( + self, + density: float | None = None, + r_cut: float = 0.8, + set_pdf_positive: bool = False, + returnval: bool = False, + ) -> list[NDArray] | None: + """ + Calculate the pair distribution function g(r) from G(r). + + Requires :meth:`calculate_Gr` to have been run first. The density + rho0 is determined by (in priority order): + + 1. The ``density`` argument, if provided. + 2. ``self.rho0``, if already cached from a prior :meth:`estimate_density` call + (e.g. via ``calculate_Gr(damp_origin_oscillations=True)``). + 3. A fresh call to :meth:`estimate_density` (result cached in ``self.rho0``). + + The G(r) used is ``self.reduced_pdf_damped`` if it exists (i.e. the user + chose damping in :meth:`calculate_Gr`), otherwise ``self.reduced_pdf``. + + Parameters + ---------- + density : float or None, optional + Number density (atoms/A^3). If None, uses cached or estimated value. 
+        r_cut : float, optional
+            Minimum radial distance (A) for peak search in density estimation.
+            Only used when density must be estimated. Forwarded to
+            :meth:`estimate_density`.
+        set_pdf_positive : bool, optional
+            If True, clamp negative g(r) values to 0.
+        returnval : bool, optional
+            If True, return ``[r, g(r)]`` as numpy arrays.
+
+        Returns
+        -------
+        list[np.ndarray] or None
+        """
+        if self._reduced_pdf is None or self._r is None:
+            raise RuntimeError(
+                "Reduced PDF not computed. "
+                "Run PairDistributionFunction.calculate_Gr() before calculate_gr()."
+            )
+
+        # Determine density
+        if density is not None:
+            rho0 = density
+        elif self.rho0 is not None:
+            rho0 = self.rho0
+            print(f" Using estimated rho0 = {rho0:.6f} atoms/A^3", flush=True)
+        else:
+            # the oscillation correction simultaneously produces a density estimate
+            # if the user didn't run damping in calculate_Gr, we can still run the density estimation without using the corrected Fk/G(r)
+            density_est = self.estimate_density(
+                r_cut=r_cut,
+                max_iter=20,
+                tol_percent=1e-1,
+            )
+            self.rho0 = density_est[0]
+            rho0 = self.rho0
+            print(f" Estimated rho0 = {rho0:.6f} atoms/A^3", flush=True)
+
+        # Use damped G(r) if the user opted into damping, otherwise undamped
+        Gr = self.reduced_pdf_damped if self.reduced_pdf_damped is not None else self._reduced_pdf
+        Gr = Gr.clone()
+
+        r = self._r
+        mask = r > 0
+        pdf = torch.ones_like(Gr)
+        # the formula for g(r) from G(r) is: g(r) = 1 + G(r) / (4 * pi * r * rho0)
+        pdf = torch.where(mask, 1 + Gr / (4 * torch.pi * r * rho0), torch.zeros_like(pdf))
+        if set_pdf_positive:  # negative values are unphysical
+            pdf = torch.maximum(pdf, torch.zeros_like(pdf))
+
+        self._pdf = pdf
+        if returnval:
+            return [self._to_numpy(self._r), self._to_numpy(self._pdf)]
+        return None
+
+    def estimate_density(
+        self,
+        density: float | None = None,
+        r_cut: float = 0.8,
+        max_iter: int = 40,
+        tol_percent: float = 1e-4,
+    ) -> tuple[float, torch.Tensor, torch.Tensor]:
+        """
+        Estimate
number density rho0 (atoms/A^3) and compute a corrected G(r). + + This method implements an iterative Q-space density estimation by + Yoshimoto & Omote (2022). It uses the structure factor `self.Sk` and + the reduced PDF `self.reduced_pdf` to iteratively update rho0 and a + corrected S(k) so that the implied G(r) is more physically consistent + at low r. + + If ``density`` is provided, the given value is used as a fixed rho0 + for the S(k)/G(r) correction instead of estimating it iteratively. + + This method requires that :meth:`calculate_Gr` has already been run, + because it depends on `self.Sk`, `self.reduced_pdf`, `self.r`, + and the k-window bounds (`self.kmin_fit`, `self.kmin_window`, + `self.kmax_window`). + + Parameters + ---------- + density : float or None, optional + Known number density (atoms/A^3). If provided, used as a fixed + rho0 — the iterative estimation is skipped and only the S(k)/G(r) + correction is performed. + r_cut : float, optional + Minimum radial distance (A) for the peak search used to determine + the correction interval. Peaks below this distance are ignored. + max_iter : int, optional + Maximum number of Q-space iterations. + tol_percent : float, optional + Convergence threshold on the relative change in rho0 (in %), + as defined in Eq. (12) of Yoshimoto & Omote (2022). + + Returns + ------- + rho0 : float + Number density (atoms/A^3), either provided or estimated. + Fk_win_damped : torch.Tensor + Windowed corrected reduced structure function used for the transform. + G_cor : torch.Tensor + Reduced PDF G(r) with dampened oscillations near origin. + """ + # we need the non-reduced structure factor (S(k) = 1 + F(k)/k) for the density estimation correction, + # so we compute it here from the Fk we already have + if self.Sk is None or self._reduced_pdf is None or self._r is None: + raise RuntimeError( + "This method depends on Sk, reduced_pdf, and r from calculate_Gr. 
" + "Run PairDistributionFunction.calculate_Gr() before estimate_density()." + ) + + k = self._to_torch(np.asarray(self.qq)) + dk = k[1] - k[0] + k_fit_mask = (k >= self.kmin_fit) & (k <= self.kmax_window) + k_fit = k[k_fit_mask] + ka, ra = torch.meshgrid(k, self._r, indexing="ij") + + # r_cut sets the minimum r for the peak search used to determine the correction interval + mask_search = self._r >= r_cut + r_search = self._r[mask_search] + G_search = self._reduced_pdf[mask_search] + # find tallest peak and first local minimum to the left of r_peak + ind_max = torch.argmax(G_search) + r_max = r_search[ind_max] + left = self._r < r_max + if not torch.any(left): + # If peak is immediately at cutoff, just use cutoff as rmin + rmin = r_cut + else: + r_left = self._r[left] + G_left = self._reduced_pdf[left] + mins_cond = (G_left[1:-1] < G_left[:-2]) & (G_left[1:-1] < G_left[2:]) + # fix indexing from slicing with +1 + mins_indices = torch.where(mins_cond)[0] + 1 + # minimum closest to the peak, else global min in left interval + if mins_indices.numel() > 0: + rmin = float(r_left[mins_indices[-1]]) + else: + rmin = float(r_left[torch.argmin(G_left)]) + # Restrict r to [0, rmin] for the correction + r_mask = (self._r >= 0.0) & (self._r <= rmin) + r_short = self._r[r_mask] + k_fit_scaled = k_fit * 2 * torch.pi + k2d_fit, r2d_fit = torch.meshgrid(k_fit_scaled, r_short, indexing="ij") + + # Iterative refinement of rho0 and S(k) + fixed_density = density is not None + rho0 = density if fixed_density else 0.0 + rho0_prev = None + Sk_cor = self.Sk.clone() + # calculate lorch function once bc it doesn't change during the iteration + wk = self._lorch_window(k, self.kmin_window, self.kmax_window) + # windowed G(r) for the iteration + Fk_win = k * (Sk_cor - 1.0) * wk * 2 * torch.pi + G_iter = ( + (2.0 / torch.pi) + * dk + * 2 + * torch.pi + * torch.sum(torch.sin(2 * torch.pi * ka * ra) * Fk_win[:, None], dim=0) + ) + G_iter[0] = 0.0 + G_beta = G_iter[r_mask] + beta_prev = None + 
for j in range(max_iter): + if j > 0: + G_beta = G_iter[r_mask] + # calculate alpha/beta for S(k) adjustment + # alpha and beta are the ideal and actual contributions to G(r) in the short-r range + # from the current S(k) and G(r) + alpha, beta = self._compute_alpha_beta(k2d_fit, r2d_fit, G_beta, r_short) + if not fixed_density: + rho0 = float(torch.sum(alpha * beta) / torch.sum(alpha**2)) + if rho0_prev is not None: + Rj = np.sqrt(((rho0_prev - rho0) ** 2) / (rho0**2)) * 100.0 + if Rj < tol_percent: + break + else: + # fixed density: converge on the S(k) correction magnitude + if beta_prev is not None: + delta = float(torch.max(torch.abs(beta - beta_prev))) + if delta < tol_percent * 1e-2: + break + beta_prev = beta.clone() + # Update S_cor(k) and G_cor + Sk_cor[k_fit_mask] = Sk_cor[k_fit_mask] - beta + rho0 * alpha + Fk_win = k * (Sk_cor - 1.0) * wk * 2 * torch.pi + G_iter = ( + (2.0 / torch.pi) + * dk + * 2 + * torch.pi + * torch.sum(torch.sin(2 * torch.pi * ka * ra) * Fk_win[:, None], dim=0) + ) + G_iter[0] = 0.0 + rho0_prev = rho0 + return rho0, Fk_win, G_iter + + # ------------------------------------------------------------------ + # Plotting functions + # ------------------------------------------------------------------ + + PlotName = Literal[ + "radial_mean", + "background_fits", + "reduced_sf", + "reduced_pdf", + "pdf", + "oscillation_damping", + ] + + def _apply_xrange( + self, + x: NDArray, + y: NDArray, + xmin: float | None, + xmax: float | None, + ) -> tuple[NDArray, NDArray]: + if xmin is None and xmax is None: + return x, y + xmin_eff = x.min() if xmin is None else xmin + xmax_eff = x.max() if xmax is None else xmax + if xmax_eff <= xmin_eff: + raise ValueError(f"xmax must be > xmin (got xmin={xmin_eff}, xmax={xmax_eff}).") + m = (x >= xmin_eff) & (x <= xmax_eff) + # avoid empty plots + if not np.any(m): + raise ValueError("Requested plot range contains no data.") + return x[m], y[m] + + def plot_pdf_results( + self, + which: Iterable[PlotName] = 
("reduced_pdf",), + *, + qmin: float | None = None, + qmax: float | None = None, + rmin: float | None = None, + rmax: float | None = None, + figsize: tuple[float, float] = (6, 4), + returnfigs: bool = False, ): """ - Stub for pair distribution function (PDF) calculation from radial statistics. + Convenience plotting dispatcher. - Intended to estimate S(k), background, and transform to real-space g(r)/G(r). + Examples + -------- + pdfc.calculate_Gr(...) + pdfc.plot(["radial_mean", "background", "reduced_pdf"]) """ - raise NotImplementedError("calculate_pair_dist_function is not implemented yet.") + mapping = { + "radial_mean": self.plot_radial_mean, + "background_fits": self.plot_background_fits, + "reduced_sf": self.plot_reduced_sf, + "reduced_pdf": self.plot_reduced_pdf, + "pdf": self.plot_pdf, + "oscillation_damping": self.plot_oscillation_damping, + } + + figs = [] + for name in which: + if name not in mapping: + raise ValueError(f"Unknown plot '{name}'. Options: {tuple(mapping)}") + fig = mapping[name]( + qmin=qmin, qmax=qmax, rmin=rmin, rmax=rmax, figsize=figsize, returnfig=returnfigs + ) + if returnfigs: + figs.append(fig) + + return figs if returnfigs else None + + def plot_radial_mean( + self, + qmin: float | None = None, + qmax: float | None = None, + rmin: float | None = None, # accepted for dispatcher compatibility, unused + rmax: float | None = None, # accepted for dispatcher compatibility, unused + figsize: tuple[float, float] = (8, 4), + returnfig: bool = False, + ): + """ + Plotting radial mean intensity vs scattering vector. + """ + + if self.Ik is None: + raise RuntimeError( + "Radial mean intensity has not been calculated yet." + "Run PairDistributionFunction.calculate_Gr() or PairDistributionFunction.calculate_radial_mean() before plotting." 
+            )
+
+        x = np.asarray(self.qq)
+        y = self._to_numpy(self.Ik)
+        x, y = self._apply_xrange(x, y, qmin, qmax)
+
+        fig, ax = plt.subplots(figsize=figsize)
+        ax.plot(x, y, label="Radial Mean Intensity I(k)")
+        ax.set_xlabel("Scattering Vector q (1/Å)")
+        ax.set_ylabel("Intensity (a.u.)")
+        ax.set_title("Radial Mean Intensity vs Scattering Vector")
+        ax.legend()
+        ax.set_yscale("log")
+        plt.tight_layout()
+
+        if returnfig:
+            return fig
+        else:
+            plt.show()
 
     def plot_background_fits(
         self,
+        qmin: float | None = None,
+        qmax: float | None = None,
+        rmin: float | None = None,  # accepted for dispatcher compatibility, unused
+        rmax: float | None = None,  # accepted for dispatcher compatibility, unused
         figsize: tuple[float, float] = (8, 4),
         returnfig: bool = False,
     ):
         """
-        Stub for plotting background fit vs radial mean intensity.
+        Plotting background fit vs radial mean intensity.
         """
-        raise NotImplementedError("plot_background_fits is not implemented yet.")
+        if self.Ik is None or self.bg is None:
+            raise RuntimeError(
+                "Radial mean intensity or background has not been calculated yet. "
+                "Run PairDistributionFunction.calculate_Gr() or both calculate_radial_mean() and calculate_background() before plotting."
+            )
+
+        x = np.asarray(self.qq)
+        y1 = self._to_numpy(self.Ik)
+        x, y1 = self._apply_xrange(x, y1, qmin, qmax)
+        x = np.asarray(self.qq)
+        y2 = self._to_numpy(self.bg)
+        x, y2 = self._apply_xrange(x, y2, qmin, qmax)
 
-    def plot_sf_estimate(
+        fig, ax = plt.subplots(figsize=figsize)
+        ax.plot(x, y1, label="Radial Mean Intensity I(k)")
+        ax.plot(x, y2, label="Background B(k)", linestyle="--")
+        ax.set_xlabel("Scattering Vector q (1/Å)")
+        ax.set_ylabel("Intensity (a.u.)")
+        ax.set_title("Radial Mean Intensity and Background Fit")
+        ax.legend()
+        ax.set_yscale("log")
+        plt.tight_layout()
+
+        if returnfig:
+            return fig
+        else:
+            plt.show()
+
+    def plot_reduced_sf(
         self,
+        qmin: float | None = None,
+        qmax: float | None = None,
+        rmin: float | None = None,  # accepted for dispatcher compatibility, unused
+        rmax: float | None = None,  # accepted for dispatcher compatibility, unused
         figsize: tuple[float, float] = (8, 4),
         returnfig: bool = False,
     ):
         """
-        Stub for plotting reduced structure factor S(k).
+        Plotting reduced structure factor F(k).
         """
-        raise NotImplementedError("plot_sf_estimate is not implemented yet.")
+        if self.Fk_masked is None:
+            raise RuntimeError(
+                "Reduced structure factor F(k) has not been calculated yet. "
+                "Run PairDistributionFunction.calculate_Gr() before plotting."
+            )
+
+        Fk = getattr(self, "Fk_damped", None)
+        if Fk is None:
+            Fk = self.Fk_masked
+
+        x = np.asarray(self.qq)
+        y = self._to_numpy(Fk)
+        x, y = self._apply_xrange(x, y, qmin, qmax)
+
+        fig, ax = plt.subplots(figsize=figsize)
+        ax.plot(x, y, label="Reduced Structure Factor F(k)")
+        ax.set_xlabel("Scattering Vector q (1/Å)")
+        ax.set_ylabel("Reduced Structure Factor F(k)")
+        plt.tight_layout()
+
+        if returnfig:
+            return fig
+        else:
+            plt.show()
 
     def plot_reduced_pdf(
         self,
+        qmin: float | None = None,  # accepted for dispatcher compatibility, unused
+        qmax: float | None = None,  # accepted for dispatcher compatibility, unused
+        rmin: float | None = None,
+        rmax: float | None = None,
+        padding_frac: float = 0.1,
         figsize: tuple[float, float] = (8, 4),
         returnfig: bool = False,
     ):
         """
-        Stub for plotting reduced PDF g(r).
+        Plotting reduced PDF g(r).
         """
-        raise NotImplementedError("plot_reduced_pdf is not implemented yet.")
+        if self._reduced_pdf is None:
+            raise RuntimeError(
+                "Reduced PDF has not been calculated yet. "
+                "Run PairDistributionFunction.calculate_Gr() before plotting."
+            )
+        Gr = self.reduced_pdf_damped if self.reduced_pdf_damped is not None else self._reduced_pdf
+
+        x = self._to_numpy(self._r)
+        y = self._to_numpy(Gr)
+        x, y = self._apply_xrange(x, y, rmin, rmax)
+
+        # Find radial value of primary peak and trough for y-limits
+        # Filter out NaN and Inf values to avoid plot errors
+        valid_mask = np.isfinite(y)
+        if np.any(valid_mask):
+            y_valid = y[valid_mask]
+            y_max = np.max(y_valid)
+            y_min = np.min(y_valid)
+        else:
+            # Fallback if all values are invalid
+            y_max = 1.0
+            y_min = -1.0
+        yrange = y_max - y_min
+        pad = padding_frac * yrange
+
+        fig, ax = plt.subplots(figsize=figsize)
+        ax.plot(x, y, label="Reduced Pair Distribution Function G(r)")
+        ax.set_xlabel("Radial Distance r (Å)")
+        ax.set_ylabel("Reduced Pair Distribution Function G(r)")
+        ax.set_ylim(y_min - pad, y_max + pad)
+        plt.tight_layout()
+
+        if returnfig:
+            return fig
+        else:
+            plt.show()
 
     def plot_pdf(
         self,
+        qmin: float | None = None,  # accepted for dispatcher compatibility, unused
+        qmax: float | None = None,  # accepted for dispatcher compatibility, unused
+        rmin: float | None = None,
+        rmax: float | None = None,
+        padding_frac: float = 0.1,
         figsize: tuple[float, float] = (8, 4),
         returnfig: bool = False,
     ):
         """
-        Stub for plotting full PDF G(r).
+        Plotting pair distribution function g(r).
         """
-        raise NotImplementedError("plot_pdf is not implemented yet.")
+        if self._reduced_pdf is None or self._pdf is None:
+            raise RuntimeError(
+                "PDF has not been calculated yet. "
+                "Run PairDistributionFunction.calculate_Gr() and then calculate_gr() before plotting."
+ ) + + x = self._to_numpy(self._r) + y = self._to_numpy(self._pdf) + x, y = self._apply_xrange(x, y, rmin, rmax) + + # Find radial value of primary peak + # Filter out NaN and Inf values to avoid plot errors + valid_mask = np.isfinite(y) + if np.any(valid_mask): + y_valid = y[valid_mask] + y_max = np.max(y_valid) + y_min = np.min(y_valid) + else: + # Fallback if all values are invalid + y_max = 1.0 + y_min = -1.0 + yrange = y_max - y_min + pad = padding_frac * yrange + + fig, ax = plt.subplots(figsize=figsize) + ax.plot(x, y, label="Pair Distribution Function g(r)") + ax.set_xlabel("Radial Distance r (Å)") + ax.set_ylabel("Pair Distribution Function g(r)") + ax.set_ylim(y_min - pad, y_max + pad) + plt.tight_layout() + + if returnfig: + return fig + else: + plt.show() + + def plot_oscillation_damping( + self, + qmin: float | None = None, # accepted for dispatcher compatibility, unused + qmax: float | None = None, # accepted for dispatcher compatibility, unused + rmin: float | None = None, + rmax: float | None = None, + padding_frac: float = 0.1, + figsize: tuple[float, float] = (8, 4), + returnfig: bool = False, + ): + if self.Fk_masked is None or self.Fk_damped is None or self.reduced_pdf_damped is None: + raise RuntimeError( + "Oscillation damping data not available. " + "Run calculate_Gr(damp_origin_oscillations=True) first." 
+ ) + + k = np.asarray(self.qq) + + # Convert torch tensors to numpy for plotting + Fk_masked = self._to_numpy(self.Fk_masked) + Fk_damped = self._to_numpy(self.Fk_damped) + r = self._to_numpy(self._r) + reduced_pdf = self._to_numpy(self._reduced_pdf) + reduced_pdf_damped = self._to_numpy(self.reduced_pdf_damped) + + fig, axes = plt.subplots(2, 2, figsize=figsize) + + # F(k) + axS_top = axes[0, 0] + axS_res = axes[1, 0] + axS_top.plot(k, Fk_masked, label="F_obs(k)", color="gray") + axS_top.plot(k, Fk_damped, label="F_cor(k)", color="red") + axS_top.set_xlabel("k (A$^{-1}$)") + axS_top.set_ylabel("F(k)") + axS_top.legend() + + axS_res.plot(k, Fk_damped - Fk_masked, color="blue") + axS_res.set_xlabel("k (A$^{-1}$)") + axS_res.set_ylabel("F_cor - F_obs") + + # G(r) + axG_top = axes[0, 1] + axG_res = axes[1, 1] + axG_top.plot(r, reduced_pdf, label="G_obs(r)", color="gray") + axG_top.plot(r, reduced_pdf_damped, label="G_cor(r)", color="red") + axG_top.set_xlabel("r (A)") + axG_top.set_ylabel("G(r)") + axG_top.legend() + + axG_res.plot(r, reduced_pdf_damped - reduced_pdf, color="blue") + axG_res.set_xlabel("r (A)") + axG_res.set_ylabel("G_cor - G_obs") + + fig.tight_layout() + + if returnfig: + return fig + else: + plt.show() diff --git a/tests/diffraction/test_polar.py b/tests/diffraction/test_polar.py new file mode 100644 index 00000000..4dc5a005 --- /dev/null +++ b/tests/diffraction/test_polar.py @@ -0,0 +1,383 @@ +import numpy as np +import pytest + +from quantem.core.datastructures.dataset2d import Dataset2d +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.datastructures.polar4dstem import Polar4dstem, auto_origin_id +from quantem.diffraction.polar import PairDistributionFunction + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def synthetic_diffraction_pattern(): + """Create a 
synthetic diffraction pattern with concentric rings.""" + ny, nx = 256, 256 + y, x = np.ogrid[:ny, :nx] + cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0 + + # Create rings with Gaussian profiles at specific radii + pattern = np.zeros((ny, nx), dtype=np.float32) + ring_radii = [10, 20, 30, 40] + r = np.sqrt((y - cy) ** 2 + (x - cx) ** 2) + for radius in ring_radii: + pattern += 100 * np.exp(-((r - radius) ** 2) / (2 * 2**2)) + # central beam + pattern += 1000 * np.exp(-(r**2) / (2 * 3**2)) + # noise + rng = np.random.default_rng(42) + pattern += rng.poisson(5, size=(ny, nx)) + + return pattern.astype(np.float32) + + +@pytest.fixture +def synthetic_4dstem_dataset(synthetic_diffraction_pattern): + """Create a synthetic 4D-STEM dataset with 3x3 scan.""" + scan_y, scan_x = 3, 3 + ny, nx = synthetic_diffraction_pattern.shape + + array_4d = np.zeros((scan_y, scan_x, ny, nx), dtype=np.float32) + for iy in range(scan_y): + for ix in range(scan_x): + # Add slight variations + rng = np.random.default_rng(42 + iy * scan_x + ix) + variation = 1.0 + 0.1 * rng.standard_normal() + array_4d[iy, ix] = synthetic_diffraction_pattern * variation + + return Dataset4dstem.from_array( + array=array_4d, + name="test_4dstem", + origin=(0, 0, 0, 0), + sampling=(1.0, 1.0, 0.015, 0.015), + units=["nm", "nm", "1/Angstrom", "1/Angstrom"], + signal_units="counts", + ) + + +@pytest.fixture +def synthetic_dataset2d(synthetic_diffraction_pattern): + """Create a synthetic 2D diffraction dataset.""" + return Dataset2d.from_array( + array=synthetic_diffraction_pattern, + name="test_2d_diffraction", + origin=(0, 0), + sampling=(0.015, 0.015), + units=["1/Angstrom", "1/Angstrom"], + signal_units="counts", + ) + + +# ============================================================================ +# Test PairDistributionFunction Construction +# ============================================================================ + + +class TestPairDistributionFunctionConstruction: + """Test PairDistributionFunction 
initialization from various input types.""" + + def test_from_data_with_dataset4dstem(self, synthetic_4dstem_dataset): + """Test construction from a Dataset4dstem object.""" + pdf = PairDistributionFunction.from_data( + synthetic_4dstem_dataset, + find_origin=False, + ) + assert isinstance(pdf.polar, Polar4dstem) + assert pdf.input_data is synthetic_4dstem_dataset + assert pdf.polar.shape[0] == 3 # scan_y + assert pdf.polar.shape[1] == 3 # scan_x + assert pdf.polar.shape[2] == 180 # num_annular_bins + + def test_from_data_with_invalid_array_raises(self): + """Test that arrays with wrong dimensions raise ValueError.""" + array_1d = np.random.rand(100) + with pytest.raises(ValueError, match="only supports 2D or 4D arrays"): + PairDistributionFunction.from_data(array_1d) + + def test_direct_init_without_token_raises(self, synthetic_dataset2d): + """Test that direct __init__ without token raises RuntimeError.""" + pdf_valid = PairDistributionFunction.from_data(synthetic_dataset2d, find_origin=False) + with pytest.raises(RuntimeError, match="Use PairDistributionFunction.from_data"): + PairDistributionFunction(polar=pdf_valid.polar, device="cpu") + + def test_find_origin(self, synthetic_4dstem_dataset): + """Test automatic origin finding.""" + origin_array = auto_origin_id( + synthetic_4dstem_dataset, + ) + assert origin_array.shape == (3, 3, 2) # (scan_y, scan_x, 2) + expected_center = 127.5 + for iy in range(3): + for ix in range(3): + row, col = origin_array[iy, ix] + assert abs(row - expected_center) < 1 + assert abs(col - expected_center) < 1 + + +# ============================================================================ +# Test Polar Transform +# ============================================================================ + + +class TestPolarTransform: + """Test polar coordinate transformation.""" + + def test_polar_transform_basic(self, synthetic_4dstem_dataset): + """Test basic polar transformation.""" + polar = synthetic_4dstem_dataset.polar_transform() + 
assert isinstance(polar, Polar4dstem) + assert polar.shape[0] == 3 # scan_y + assert polar.shape[1] == 3 # scan_x + assert polar.shape[2] == 180 # num_annular_bins + assert polar.shape[3] > 0 # radial bins + + def test_polar_transform_single_origin(self, synthetic_4dstem_dataset): + """Test polar transformation with single origin broadcast to all positions.""" + origin = np.array([128.0, 128.0]) + polar = synthetic_4dstem_dataset.polar_transform( + origin_array=origin, + ) + assert isinstance(polar, Polar4dstem) + + def test_polar_transform_radial_range(self, synthetic_4dstem_dataset): + """Test polar transformation with custom radial range.""" + polar = synthetic_4dstem_dataset.polar_transform( + radial_min=5.0, + radial_max=50.0, + radial_step=2.0, + ) + assert isinstance(polar, Polar4dstem) + # Check that radial dimension matches expected size + expected_n_r = int(np.ceil((50.0 - 5.0) / 2.0)) + assert polar.shape[3] == expected_n_r + + def test_polar_transform_scan_pos(self, synthetic_4dstem_dataset): + """Test polar transformation for a single scan position.""" + polar_2d = synthetic_4dstem_dataset.polar_transform( + scan_pos=(0, 0), + ) + # should return 2D tensor (phi, r) + assert polar_2d.ndim == 2 + assert polar_2d.shape[0] == 180 # num_annular_bins + + +# ============================================================================ +# Test Radial Mean Calculation +# ============================================================================ + + +class TestRadialMeanCalculation: + """Test radial mean intensity calculation.""" + + def test_calculate_radial_mean_with_mask(self, synthetic_4dstem_dataset): + """Test radial mean calculation with real-space mask.""" + pdf = PairDistributionFunction.from_data( + synthetic_4dstem_dataset, + find_origin=False, + ) + mask = np.zeros((3, 3), dtype=bool) + mask[0:2, 0:2] = True + radial_mean = pdf.calculate_radial_mean( + mask_realspace=mask, + returnval=True, + ) + assert radial_mean is not None + + +# 
============================================================================ +# Test Background Fitting +# ============================================================================ + + +class TestBackgroundFitting: + """Test background fitting.""" + + def test_fit_bg_basic(self, synthetic_dataset2d): + """Test basic background fitting.""" + pdf = PairDistributionFunction.from_data( + synthetic_dataset2d, + find_origin=False, + ) + Ik = pdf.calculate_radial_mean(returnval=True) + k = pdf._to_torch(np.asarray(pdf.qq)) + kmin, kmax = float(k.min()), float(k.max()) + bg, f = pdf.fit_bg(Ik, kmin=kmin * 0.1, kmax=kmax * 0.9) + assert bg.shape == Ik.shape + assert f.shape == Ik.shape + # Check that background is positive + assert (bg >= 0).all() + + +# ============================================================================ +# Test PDF Calculation +# ============================================================================ + + +class TestPDFCalculation: + """Test the PDF calculation pipeline.""" + + def test_calculate_Gr_with_bandpass(self, synthetic_dataset2d): + """Test PDF calculation with bandpass filtering.""" + pdf = PairDistributionFunction.from_data( + synthetic_dataset2d, + find_origin=False, + ) + pdf.calculate_Gr( + k_min=0.1, + k_max=2.0, + k_lowpass=0.02, + k_highpass=0.001, + ) + assert pdf.reduced_pdf is not None + + def test_calculate_Gr_with_mask(self, synthetic_4dstem_dataset): + """Test PDF calculation with real-space mask.""" + pdf = PairDistributionFunction.from_data( + synthetic_4dstem_dataset, + find_origin=False, + ) + mask = np.zeros((3, 3), dtype=bool) + mask[0:2, 0:2] = True + pdf.calculate_Gr( + k_min=0.1, + k_max=2.0, + mask_realspace=mask, + ) + assert pdf.reduced_pdf is not None + + def test_calculate_gr_requires_Gr(self, synthetic_dataset2d): + """Test that calculate_gr raises if calculate_Gr has not been run.""" + pdf = PairDistributionFunction.from_data( + synthetic_dataset2d, + find_origin=False, + ) + with 
pytest.raises(RuntimeError, match="Reduced PDF not computed"): + pdf.calculate_gr(density=0.05) + + def test_calculate_gr_estimates_density(self, synthetic_dataset2d): + """Test that calculate_gr estimates density when none is provided.""" + pdf = PairDistributionFunction.from_data( + synthetic_dataset2d, + find_origin=False, + ) + pdf.calculate_Gr(k_min=0.1, k_max=2.0) + results = pdf.calculate_gr(returnval=True) + assert results is not None + r, gr = results + assert isinstance(gr, np.ndarray) + assert len(gr) == len(r) + assert pdf.rho0 > 0 + + def test_estimate_density_requires_Gr(self, synthetic_dataset2d): + """Test that estimate_density requires prior calculate_Gr call.""" + pdf = PairDistributionFunction.from_data( + synthetic_dataset2d, + find_origin=False, + ) + with pytest.raises( + RuntimeError, match="depends on Sk, reduced_pdf, and r from calculate_Gr" + ): + pdf.estimate_density() + + +# ============================================================================ +# Integration Workflows +# ============================================================================ + + +class TestIntegrationWorkflows: + """Test complete end-to-end workflows.""" + + def test_complete_pdf_workflow_2d(self, synthetic_dataset2d): + """Test: 2D diffraction → polar transform → G(r) → g(r).""" + pdf = PairDistributionFunction.from_data( + synthetic_dataset2d, + find_origin=False, + ) + Gr_results = pdf.calculate_Gr( + k_min=0.1, + k_max=2.0, + r_min=0.0, + r_max=10.0, + r_step=0.05, + returnval=True, + ) + assert Gr_results is not None + r, Gr = Gr_results + assert not np.isnan(r).any() + assert not np.isnan(Gr).any() + assert not np.isinf(Gr).any() + assert len(r) > 0 + assert len(Gr) == len(r) + gr_results = pdf.calculate_gr( + density=0.05, + returnval=True, + ) + assert gr_results is not None + r_gr, gr = gr_results + assert not np.isnan(gr).any() + assert not np.isinf(gr).any() + assert len(gr) == len(r_gr) + + def test_complete_pdf_workflow_4dstem(self, 
synthetic_4dstem_dataset): + """Test: 4D-STEM → origin finding → polar transform → G(r).""" + pdf = PairDistributionFunction.from_data( + synthetic_4dstem_dataset, + find_origin=True, + ) + mask = np.zeros((3, 3), dtype=bool) + mask[0:2, 0:2] = True + pdf.calculate_Gr( + k_min=0.1, + k_max=2.0, + mask_realspace=mask, + ) + assert pdf.reduced_pdf is not None + assert not np.isnan(pdf.reduced_pdf).any() + assert not np.isinf(pdf.reduced_pdf).any() + + def test_polar_transform_input_types(self, synthetic_diffraction_pattern): + """Test polar_transform works with numpy array, Dataset2d, Dataset4dstem.""" + # Test with Dataset2d + ds2 = Dataset2d.from_array( + array=synthetic_diffraction_pattern, + name="test", + ) + pdf_ds2 = PairDistributionFunction.from_data( + ds2, + find_origin=False, + ) + assert pdf_ds2.polar.shape[2] == 180 + + # Test with Dataset4dstem + array_4d = synthetic_diffraction_pattern[None, None, :, :] # (1, 1, ny, nx) + ds4 = Dataset4dstem.from_array(array_4d, name="test") + pdf_ds4 = PairDistributionFunction.from_data( + ds4, + find_origin=False, + ) + assert pdf_ds4.polar.shape[2] == 180 + assert pdf_ds2.polar.shape == pdf_ds4.polar.shape + + def test_density_estimation_workflow(self, synthetic_dataset2d): + """Test: G(r) calculation → density estimation → g(r) calculation.""" + pdf = PairDistributionFunction.from_data( + synthetic_dataset2d, + find_origin=False, + ) + pdf.calculate_Gr(k_min=0.1, k_max=2.0) + rho0, Fk_damped, G_cor = pdf.estimate_density( + max_iter=5, + tol_percent=1.0, + ) + assert rho0 > 0 + assert np.isfinite(rho0) + results = pdf.calculate_gr( + density=rho0, + returnval=True, + ) + assert results is not None + r, gr = results + assert not np.isnan(gr).any()