diff --git a/src/quantem/core/utils/augment_dp.py b/src/quantem/core/utils/augment_dp.py index b7b55f91..44a92022 100644 --- a/src/quantem/core/utils/augment_dp.py +++ b/src/quantem/core/utils/augment_dp.py @@ -1,4 +1,5 @@ import os +import warnings from typing import TYPE_CHECKING, Union import numpy as np @@ -20,17 +21,17 @@ ArrayLike = Union[np.ndarray, "torch.Tensor"] + # TODO # add dark background # add gaussian noise - - class DPAugmentor(RNGMixin): def __init__( self, add_bkg: bool = False, bkg_weight: list[float] | float = [0.001, 0.05], bkg_q: list[float] | float = [0.01, 0.1], + apply_background_to_label: list[bool] | None = None, add_shot: bool = False, e_dose: list[float] | float = [1e4, 1e7], add_shift: bool = False, @@ -41,6 +42,9 @@ def __init__( add_ellipticity_to_label: bool = True, add_salt_and_pepper: bool = False, salt_and_pepper: list[float] | float = [0, 5e-4], + add_gaussian_noise: bool = False, + gaussian_noise_mu: float = 0.0, + gaussian_noise_std: float = 1e-5, add_scale: bool = False, scale_factor: list[float] | float = [0.9, 1.1], add_blur: bool = False, @@ -51,6 +55,9 @@ def __init__( log_file: os.PathLike | None = None, rng: np.random.Generator | int | None = None, device: str = "cpu", + add_aperture: bool = False, + radius_factor: list[float] | float = [0.8, 1], + aperture_shift: list[float] | float = [0, 10], ): """ Initialize diffraction pattern augmentor with configurable transformations. @@ -63,7 +70,9 @@ def __init__( Range for background weight (fraction of total intensity). bkg_q : list[float] | float, default=[0.01, 0.1] Range for plasmon scattering parameter q₀ in 1/(q² + q₀²) form factor. - + apply_background_to_label: list[bool] | None, default=None + Flag for whether background should be applied to labels, and which ones based on 1/0 list. + List of 1/0 for if background should be applied to label. None if no application. add_shot : bool, default=False Enable Poisson shot noise based on electron dose. e_dose : list[float] | float, default=[1e4, 1e7] @@ -89,6 +98,15 @@ def __init__( salt_and_pepper : list[float] | float, default=[0, 5e-4] Range for fraction of pixels affected by salt and pepper noise. + add_gaussian_noise : bool, default=False + Enable gaussian noise. + gaussian_noise_mu : float, default=0.0 + Mean for gaussian noise distribution. Should be 0 for scientifically accurate representation. + Scaled by electron dose. So value of 0.1 represents mean = 10% of electron dose. + gaussian_noise_std : float, defualt=1e-5 + Standard deviation for gaussian noise distribution. + Scaled by electron dose. So value of 0.1 represents std. dev. = 10% of electron dose. + add_scale : bool, default=False Enable uniform scaling of the diffraction pattern. scale_factor : list[float] | float, default=[0.9, 1.1] @@ -114,9 +132,19 @@ def __init__( device : str, default="cpu" Device for computations ("cpu", "cuda", "cuda:0", etc.). + add_aperture : bool, default=False + Enable circular aperture mask to simulate objective aperture effects. + radius_factor : list[float] | float, default=[0.8, 1] + Range for aperture radius as fraction of maximum image radius (distance from + center to corner). Values < 1 create vignetted diffraction patterns. The mask + is centered at (height//2, width//2) + aperture_shift. + aperture_shift : list[float] | float, default=[0, 10] + Range for random shift of aperture center in pixels (applied with random sign + to both x and y). Simulates misalignment of the objective aperture. + Notes ----- - - Augmentations are applied in order: flipshift → background → elastic → + - Augmentations are applied in order: flipshift → elastic → background → shot noise → blur → salt & pepper - For labels, only geometric transforms (flipshift, elastic) are applied - Ellipticity creates anisotropic scaling via exx, eyy, exy parameters @@ -126,88 +154,24 @@ def __init__( self._setup_device(device) self.log_file = log_file - self.set_params( - add_bkg, - bkg_weight, - bkg_q, - add_shot, - e_dose, - add_shift, - xshift, - yshift, - add_ellipticity, - ellipticity_scale, - add_ellipticity_to_label, - add_salt_and_pepper, - salt_and_pepper, - add_scale, - scale_factor, - add_blur, - blur_sigma, - add_flipshift, - free_rotation, - rotation_range, - ) - self.generate_params() - self._init_log_file() - - def _setup_device(self, device: str) -> None: - if device == "gpu" or device.startswith("cuda"): - if not config.get("has_torch"): - raise RuntimeError("torch required for GPU operations but not available") - self.device = device if device.startswith("cuda") else "cuda" - self.use_torch = True - else: - self.device = "cpu" - self.use_torch = False - - if hasattr(self, "_rng_seed") and self._rng_seed is not None: - self._rng_to_device(self.device) - - def _init_log_file(self) -> None: - if self.log_file is not None: - with open(self.log_file, "a") as f: - f.write( - "bkg_weight,bkg_q,e_dose,xshift,yshift,exx,eyy,exy," - "scale_factor,flip_horizontal,flip_vertical,rotation_angle," - "blur_sigma,salt_and_pepper,rng_seed\n" - ) - - def set_params( - self, - add_bkg: bool = False, - bkg_weight: list[float] | float = [0.01, 0.1], - bkg_q: list[float] | float = [0.01, 0.1], - add_shot: bool = False, - e_dose: list[float] | float = [1e5, 1e10], - add_shift: bool = False, - xshift: list[float] | float = [0, 10], - yshift: list[float] | float = [0, 10], - add_ellipticity: bool = False, - ellipticity_scale: list[float] | float = [0, 0.15], - add_ellipticity_to_label: bool = True, - add_salt_and_pepper: bool = False, - salt_and_pepper: list[float] | float = [0, 1e-3], - add_scale: bool = False, - scale_factor: list[float] | float = [0.9, 1.1], - add_blur: bool = False, - blur_sigma: list[float] | float = [0.0, 1.5], - add_flipshift: bool = False, - free_rotation: bool = False, - rotation_range: list[float] | float = [-180, 180], - ) -> None: + # Setting parameters self.add_bkg = add_bkg self.add_shot = add_shot self.add_shift = add_shift self.add_ellipticity = add_ellipticity - self.add_ellipticity_to_label = add_ellipticity_to_label + self.add_ellipticity_to_label = add_ellipticity_to_label or [] self.add_salt_and_pepper = add_salt_and_pepper + self.add_gaussian_noise = add_gaussian_noise + self.gaussian_noise_mu = gaussian_noise_mu + self.gaussian_noise_std = gaussian_noise_std self.add_scale = add_scale self.add_blur = add_blur self.add_flipshift = add_flipshift + self.add_aperture = add_aperture self._bkg_weight_range = self._check_input(bkg_weight) if add_bkg else [0, 0] self._bkg_q_range = self._check_input(bkg_q) if add_bkg else [0, 0] + self.apply_background_to_label = apply_background_to_label self._e_dose_range = self._check_input(e_dose) if add_shot else [np.inf, np.inf] self._xshift_range = self._check_input(xshift) if add_shift else [0, 0] self._yshift_range = self._check_input(yshift) if add_shift else [0, 0] @@ -223,6 +187,35 @@ def set_params( self.free_rotation = free_rotation self._rotation_range = self._check_input(rotation_range) if add_flipshift else [0, 0] + self._radius_range = self._check_input(radius_factor) if add_aperture else [0, 0] + self._aptshift_range = self._check_input(aperture_shift) if add_aperture else [0, 0] + + # Generate parameters from set parameters + self.generate_params() + self._init_log_file() + + def _setup_device(self, device: str) -> None: + if device == "gpu" or device.startswith("cuda"): + if not config.get("has_torch"): + raise RuntimeError("torch required for GPU operations but not available") + self.device = device if device.startswith("cuda") else "cuda" + self.use_torch = True + else: + self.device = "cpu" + self.use_torch = False + + if hasattr(self, "_rng_seed") and self._rng_seed is not None: + self._rng_to_device(self.device) + + def _init_log_file(self) -> None: + if self.log_file is not None: + with open(self.log_file, "a") as f: + f.write( + "bkg_weight,bkg_q,apply_background_to_label,e_dose,xshift,yshift,exx,eyy,exy," + "gaussian_noise_mu,gaussian_noise_std,scale_factor,flip_horizontal,flip_vertical," + "rotation_angle,blur_sigma,salt_and_pepper,rng_seed\n" + ) + def generate_params(self) -> None: self.bkg_weight = self._uniform_or_zero(self._bkg_weight_range, self.add_bkg) self.bkg_q = self._uniform_or_zero(self._bkg_q_range, self.add_bkg) @@ -233,6 +226,8 @@ def generate_params(self) -> None: self.blur_sigma = self._uniform_or_zero(self._blur_range, self.add_blur) self.xshift = self._uniform_with_sign(self._xshift_range, self.add_shift) self.yshift = self._uniform_with_sign(self._yshift_range, self.add_shift) + self.xshiftapt = self._uniform_with_sign(self._aptshift_range, self.add_aperture) + self.yshiftapt = self._uniform_with_sign(self._aptshift_range, self.add_aperture) self._generate_ellipticity_params() self._generate_flipshift_params() @@ -241,6 +236,11 @@ def generate_params(self) -> None: else: self.scale_factor = 0 + if self.add_aperture: + self.radius_factor = self.rng.uniform(self._radius_range[0], self._radius_range[1]) + else: + self.radius_factor = 0 + def _uniform_or_zero(self, range_vals: list, enabled: bool) -> float: return self.rng.uniform(range_vals[0], range_vals[1]) if enabled else 0 @@ -310,6 +310,7 @@ def print_params(self, print_all: bool = False) -> None: f"Flip: H={self.flip_horizontal}, V={self.flip_vertical}, Rot: {self.rotation_angle:.1f}°", ), ("Salt & pepper", self.add_salt_and_pepper, f"Amount: {self.salt_and_pepper:.2e}"), + ("Gaussian noise", self.add_gaussian_noise, f"Mean: {self.gaussian_noise_mu:.2e}", f"Std: {self.gaussian_noise_std:.2e}"), ("Gaussian blur", self.add_blur, f"Sigma: {self.blur_sigma:.2f}"), ] @@ -360,17 +361,32 @@ def _augment_stack( if probe_stack is not None and probe_stack.shape[0] != batch_size: raise ValueError(f"Probe stack size {probe_stack.shape[0]} != DP size {batch_size}") - if label_stack is not None and label_stack.shape[0] != batch_size: + # Make exception for batch_size of 1 + if batch_size == 1 and len(label_stack.shape) == 3: + pass + elif label_stack is not None and label_stack.shape[0] != batch_size: raise ValueError(f"Label stack size {label_stack.shape[0]} != DP size {batch_size}") augmented_dps = [] augmented_labels = [] if label_stack is not None else None - for i in tqdm(range(batch_size), desc="augmenting"): + # Create iterator with condition for batch_size of 1 + iterator = tqdm(range(batch_size), desc="augmenting") if batch_size > 1 else range(batch_size) + for i in iterator: dp_single = dp_stack[i] probe_single = probe_stack[i] if probe_stack is not None else None - label_single = label_stack[i] if label_stack is not None else None - + + # Check for multichannel labels + if label_stack is not None: + if batch_size == 1 and len(label_stack.shape) == 3: + # Single image with multichannel labels + label_single = label_stack # Use entire multichannel label + else: + # If multiple images take labels for current iterant + label_single = label_stack[i] + else: + label_single = None + if label_single is not None: aug_dp, aug_label = self._augment_single(dp_single, probe_single, label_single) augmented_dps.append(aug_dp) @@ -382,13 +398,21 @@ def _augment_stack( if self.use_torch: stacked_dps = torch.stack(augmented_dps) # type: ignore if augmented_labels is not None: - stacked_labels = torch.stack(augmented_labels) # type: ignore + # Check for batch size of 1 + if batch_size == 1 and len(label_stack.shape) == 3: + stacked_labels = augmented_labels[0] # If multichannel just return, don't stack + else: + stacked_labels = torch.stack(augmented_labels) # type: ignore return stacked_dps, stacked_labels return stacked_dps else: stacked_dps = np.stack(augmented_dps) if augmented_labels is not None: - stacked_labels = np.stack(augmented_labels) + # Check for batch size of 1 + if batch_size == 1 and len(label_stack.shape) == 3: + stacked_labels = augmented_labels[0] # If multichannel just return, don't stack + else: + stacked_labels = np.stack(augmented_labels) return stacked_dps, stacked_labels return stacked_dps @@ -401,15 +425,34 @@ def _augment_single( if self.add_flipshift: result = self._apply_flipshift(result) if transformed_label is not None: - transformed_label = self._apply_flipshift(transformed_label) - if self.add_bkg: - result = self._apply_bkg(result, probe) + # Check if label is multichannel + if len(transformed_label.shape) == 3: + transformed_label = self._apply_flipshift_to_multichannel_label(label) + else: + transformed_label = self._apply_flipshift(label) + if self.add_ellipticity or self.add_shift or self.add_scale: result = self._apply_elastic(result) if transformed_label is not None: - transformed_label = self._apply_elastic_to_label(transformed_label) + # Check if label is multichannel + if len(transformed_label.shape) == 3: + transformed_label = self._apply_elastic_to_multichannel_label(transformed_label) + else: + transformed_label = self._apply_elastic_to_label(transformed_label) + + if self.add_bkg: + result = self._apply_bkg(result, probe) + if transformed_label is not None and self.apply_background_to_label is not None: + if len(self.apply_background_to_label) > 0: + if len(transformed_label.shape) == 3: + transformed_label = self._apply_bkg_to_multichannel_label(transformed_label, probe) + + if self.add_aperture: # currently input can only be Tensor + result = self._apply_aperture(result) if self.add_shot: result = self._apply_shot(result) + if self.add_gaussian_noise: + result = self._apply_gaussian_noise(result) if self.add_blur: result = self._apply_blur(result) if self.add_salt_and_pepper: @@ -461,6 +504,34 @@ def _maybe_switch_to_torch( self.use_torch = True self._rng_to_device(self.device) + def _apply_flipshift_to_multichannel_label(self, label: ArrayLike) -> ArrayLike: + """Apply flipshift to multichannel label""" + if len(label.shape) == 3: # Multichannel (C, H, W) + transformed_channels = [] + for c in range(label.shape[0]): + transformed_channels.append(self._apply_flipshift(label[c])) + if self.use_torch: + return torch.stack(transformed_channels) + else: + return np.stack(transformed_channels) + else: + # Single channel label + return self._apply_flipshift(label) + + def _apply_elastic_to_multichannel_label(self, label: ArrayLike) -> ArrayLike: + """Apply elastic transforms to multichannel label""" + if len(label.shape) == 3: # Multichannel (C, H, W) + transformed_channels = [] + for c in range(label.shape[0]): + transformed_channels.append(self._apply_elastic_to_label(label[c])) + if self.use_torch: + return torch.stack(transformed_channels) + else: + return np.stack(transformed_channels) + else: + # Single channel label + return self._apply_elastic_to_label(label) + def _apply_shot(self, inputs: ArrayLike) -> ArrayLike: """Apply Poisson shot noise""" if self.use_torch: @@ -468,14 +539,35 @@ def _apply_shot(self, inputs: ArrayLike) -> ArrayLike: offset = image.min() image = (image - offset) / (image - offset).sum() return torch.poisson(image * self.e_dose, generator=self._rng_torch) + offset + # Below version preserves total intensity + # sum_int = (image - offset).sum() + # image = (image - offset) / sum_int + # return torch.poisson(image * self.e_dose, generator=self._rng_torch) * sum_int / self.e_dose + offset else: image = np.array(inputs) offset = image.min() image = (image - offset) / (image - offset).sum() return self.rng.poisson(image * self.e_dose) + offset + def _apply_aperture(self, inputs: "torch.Tensor") -> "torch.Tensor": + height, width = inputs.shape + device = inputs.device + y, x = torch.meshgrid( + torch.arange(height, dtype=torch.float32, device=device), + torch.arange(width, dtype=torch.float32, device=device), + indexing="ij", + ) + y_center, x_center = height // 2, width // 2 + y = y.clone() - y_center + self.yshiftapt + x = x.clone() - x_center + self.xshiftapt + r = torch.sqrt(x**2+y**2) + + aperture_mask = (r <= self.radius_factor*np.sqrt(y_center**2+x_center**2)).float() + output = inputs * aperture_mask + return output + def _apply_elastic(self, inputs: ArrayLike) -> ArrayLike: - """Apply elastic transformations (scaling, rotation, translation)""" + """Apply elastic transformations (scaling, translation)""" if self.use_torch: return self._apply_elastic_torch(inputs) # type: ignore else: @@ -502,7 +594,7 @@ def _apply_elastic_torch(self, inputs: "torch.Tensor") -> "torch.Tensor": if self.add_shift: x_new += self.xshift y_new += self.yshift - + x_norm = 2.0 * x_new / (width - 1) - 1.0 y_norm = 2.0 * y_new / (height - 1) - 1.0 grid = torch.stack([x_norm, y_norm], dim=-1).unsqueeze(0) @@ -540,7 +632,10 @@ def _apply_bkg(self, inputs: ArrayLike, probe: ArrayLike | None = None) -> Array qx = af.view(af.sort(af.fftfreq(height, 0.1, like=inputs), axis=0), (-1, 1)) qy = af.view(af.sort(af.fftfreq(width, 0.1, like=inputs), axis=0), (1, -1)) - CBEDbg = 1.0 / (qx**2 + qy**2 + self.bkg_q**2) # Plasmon form factor: 1/(q² + q₀²) + qxc = self.yshift / (height*0.1) + qyc = self.xshift / (width*0.1) + + CBEDbg = 1.0 / ((qx+qxc)**2 + (qy+qyc)**2 + self.bkg_q**2) # Plasmon form factor: 1/(q² + q₀²) CBEDbg = CBEDbg.squeeze() / af.sum(CBEDbg.squeeze()) if probe is not None: @@ -551,6 +646,31 @@ def _apply_bkg(self, inputs: ArrayLike, probe: ArrayLike | None = None) -> Array inputs_float = af.as_type(inputs, torch.float32 if self.use_torch else np.float32) return inputs_float * (1 - self.bkg_weight) + CBEDbgConv.real * self.bkg_weight + def _apply_bkg_to_multichannel_label(self, label: ArrayLike, probe: ArrayLike | None = None) -> ArrayLike: + """Apply background to specified channels of multichannel label""" + if len(label.shape) != 3: + warnings.warn(f"Expected shape (C,H,W), got {label.shape}. Returning unchanged.", stacklevel=2) + return label + + if len(self.background_label_application) == 0: + warnings.warn("background_label_application is empty. Returning unchanged.", stacklevel=2) + return label + + # Process each channel + result_channels = [] + for c in range(label.shape[0]): + if c < len(self.apply_background_to_label) and self.apply_background_to_label[c]: + # Apply background to this channel per apply_background_to_label + result_channels.append(self._apply_bkg(label[c], probe)) + else: + # Keep channel as-is + result_channels.append(label[c]) + + if self.use_torch: + return torch.stack(result_channels) + else: + return np.stack(result_channels) + def _apply_blur(self, inputs: ArrayLike) -> ArrayLike: """Apply Gaussian blur""" if self.use_torch: @@ -611,14 +731,33 @@ def _get_salt_and_pepper( out[flipped & ~salted] = pepper_val return out + def _apply_gaussian_noise(self, inputs: ArrayLike) -> ArrayLike: + # Constant background applied to everything, scaled by electron dose + # Gaussian uniform to whole image, clipped to 0 + # Just camera noise, electronic noise + # Just some random scale value (std 5 e- for example, mean is std, then clip. Makes it so gaussian shifted so half isn't negative) + mean = self.gaussian_noise_mu * self.e_dose if self.add_shot else self.gaussian_noise_mu + std = self.gaussian_noise_std * self.e_dose if self.add_shot else self.gaussian_noise_std + + if self.use_torch: + image = inputs.clone() + noise = torch.clip(torch.normal(mean=mean, std=std, size=inputs.shape), min=0) + image += noise + return image + else: + image = np.array(inputs).copy() + noise = np.clip(np.random.normal(loc=mean, scale=std, size=inputs.shape), a_min=0, a_max=None) + image += noise + return image + def write_logs(self) -> None: if self.log_file is None: return with open(self.log_file, "a") as f: f.write( - f"{self.bkg_weight},{self.bkg_q},{self.e_dose},{self.xshift}," + f"{self.bkg_weight},{self.bkg_q},{self.apply_background_to_label},{self.e_dose},{self.xshift}," f"{self.yshift},{self.exx},{self.eyy},{self.exy}," - f"{self.scale_factor},{self.flip_horizontal},{self.flip_vertical}," + f"{self.gaussian_noise_mu},{self.gaussian_noise_std},{self.scale_factor},{self.flip_horizontal},{self.flip_vertical}," f"{self.rotation_angle},{self.blur_sigma},{self.salt_and_pepper}," f"{self._rng_seed}\n" )