3 changes: 3 additions & 0 deletions requirements.txt
@@ -19,3 +19,6 @@ wget>=3.2
ty
ruff
beautifulsoup4
h5py
cdflib
netcdf4
150 changes: 113 additions & 37 deletions swvo/io/RBMDataSet/RBMDataSet.py
@@ -9,9 +9,11 @@
from __future__ import annotations

import datetime as dt
import logging
import warnings
from datetime import timedelta, timezone
from pathlib import Path
from typing import Any, Literal, cast
from typing import Any, Literal, Optional, cast

import distance
import numpy as np
@@ -41,10 +43,14 @@
join_var,
load_file_any_format,
matlab2python,
read_all_datasets_cdf,
read_all_datasets_h5,
read_all_datasets_netcdf,
)
from swvo.io.utils import enforce_utc_timezone

logger = logging.getLogger(__name__)


class RBMDataSet:
"""RBMDataSet class supporting .mat, .pickle, and .nc file formats.
@@ -67,14 +73,16 @@ class RBMDataSet:
Start time for file-based loading.
end_time : dt.datetime, optional
End time for file-based loading.
folder_path : Path, optional
folder_path : Optional[Path | str]
Base folder path for file-based loading.
preferred_extension : Literal["mat", "pickle", "nc"], optional
Preferred file extension for file-based loading. Default is "pickle".
preferred_extension : Literal["mat", "pickle", "nc", "cdf", "h5"], optional
Preferred file extension for file-based loading. Default is "nc".
verbose : bool, optional
Whether to print verbose output. Default is True.
Whether to log verbose output. Default is True.
enable_dict_loading : bool, optional
Enable dictionary-based loading even in file mode. Default is False.
dataorg : bool, optional
Whether to load files saved using the DataOrgStrategy. Default is False.

Attributes
----------
@@ -102,7 +110,7 @@ class RBMDataSet:

"""

_preferred_ext: Literal["mat", "pickle", "nc"]
_preferred_ext: Literal["mat", "pickle", "nc", "cdf", "h5"]

datetime: list[dt.datetime]
time: NDArray[np.float64]
@@ -133,11 +141,12 @@ def __init__(
mfm: MfmLike,
start_time: dt.datetime | None = None,
end_time: dt.datetime | None = None,
folder_path: Path | None = None,
preferred_extension: Literal["mat", "pickle", "nc"] = "nc",
folder_path: Optional[Path | str] = None,
preferred_extension: Literal["mat", "pickle", "nc", "cdf", "h5"] = "nc",
*,
verbose: bool = True,
enable_dict_loading: bool = False,
dataorg: bool = False,
) -> None:
self.possible_variables: list[str] = list(VariableLiteral.__args__)

@@ -156,9 +165,21 @@ def __init__(
if isinstance(mfm, str):
mfm = MfmEnum[mfm.upper()]

if preferred_extension == "pickle":
warnings.warn(
"The '.pickle' file format is deprecated and will be removed in a future release",
FutureWarning,
stacklevel=2,
)
# Validate preferred_extension
if preferred_extension not in ("mat", "pickle", "nc"):
msg = f"preferred_extension must be 'mat', 'pickle', or 'nc', got '{preferred_extension}'"
if preferred_extension not in ("mat", "pickle", "nc", "cdf", "h5"):
msg = f"preferred_extension must be 'mat', 'pickle', 'nc', 'cdf', or 'h5', got '{preferred_extension}'"
raise ValueError(msg)
if dataorg and preferred_extension in ("nc", "cdf", "h5"):
msg = "dataorg = True is only supported with 'mat' or 'pickle' extensions"
raise ValueError(msg)
if not dataorg and preferred_extension == "pickle":
msg = "preferred_extension='pickle' is only supported with dataorg=True"
raise ValueError(msg)

# Store the original satellite enum for properties and other attributes
@@ -186,13 +207,14 @@ def __init__(
self._folder_path = Path(folder_path)
self._folder_type = self._satellite.folder_type
self._file_path_stem = self._create_file_path_stem()
self._is_nc_dataset = self._check_if_nc_dataset()
self._is_dataorg_dataset = dataorg
self._is_monthly_dataset = self._check_if_monthly_dataset()
self._file_name_stem = self._create_file_name_stem()
self._file_cadence = self._satellite.file_cadence
self._date_of_files = self._create_date_list()
self._file_loading_mode = True
self._enable_dict_loading = enable_dict_loading
self._netcdf_dataset_cache: dict[Path, dict[str, Any]] = {}
self._monthly_dataset_cache: dict[Path, dict[str, Any]] = {}

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._satellite}, {self._instrument}, {self._mfm})"
@@ -234,16 +256,17 @@ def __getattr__(self, name: str) -> NDArray[np.float64]:
return getattr(self, name)

if not self._file_loading_mode and name in self.possible_variables:
raise AttributeError(
f"Attribute '{name}' exists in `VariableLiteral` but has not been set. "
"Call `update_from_dict()` before accessing it."
)
msg = f"Attribute '{name}' exists in `VariableLiteral` but has not been set. Call `update_from_dict()` before accessing it."
logger.error(msg)
raise AttributeError(msg)

if levenstein_info["min_distance"] <= 2:
msg = f"{self.__class__.__name__} object has no attribute {name}. Maybe you meant {levenstein_info['var_name']}?"
elif name == "custom":
msg = f"{self.__class__.__name__} object might have custom variables. However, to access them, you first have to load any of the standard variables to trigger the loading process. After that, custom variables can be accessed via `.custom['custom_var_name']`."
else:
msg = f"{self.__class__.__name__} object has no attribute {name}"

logger.error(msg)
raise AttributeError(msg)

def load(self, name_or_var: str | VariableEnum) -> None:
@@ -328,16 +351,23 @@ def update_from_dict(
def get_var(self, var: VariableEnum) -> NDArray[np.float64]:
return getattr(self, var.var_name)

def _check_if_nc_dataset(self) -> bool:
def _check_if_monthly_dataset(self) -> bool:
does_processed_mat_files_folder_exist = (self._file_path_stem / "Processed_Mat_Files").exists()

if self._is_dataorg_dataset and self._preferred_ext in ["mat", "pickle"]:
if not does_processed_mat_files_folder_exist:
logger.warning("`dataorg` is set to True but the Processed_Mat_Files folder does not exist.")
return False
if not self._is_dataorg_dataset and self._preferred_ext in ["mat", "h5", "nc", "cdf"]:
return True

if does_processed_mat_files_folder_exist and self._preferred_ext in ["mat", "pickle"]:
return False
elif does_processed_mat_files_folder_exist and self._preferred_ext == "nc":
# if any .nc files are stored in the file_path_stem, we switch to nc mode
return next(self._file_path_stem.glob("*.nc"), None) is not None
elif does_processed_mat_files_folder_exist and self._preferred_ext in ["nc", "cdf", "h5"]:
# if any files with the preferred extension exist in the file_path_stem, we switch to non-dataorg mode
return next(self._file_path_stem.glob(f"*.{self._preferred_ext}"), None) is not None
else:
# if the Processed_Mat_Files folder does not exist, it is safe to assume nc mode
# if the Processed_Mat_Files folder does not exist, it is safe to assume non-dataorg mode
return True

def _create_date_list(self) -> list[dt.datetime]:
@@ -420,39 +450,55 @@ def _load_variable(self, var: Variable | VariableEnum) -> None:
date_str = f"{start_month.strftime('%Y%m%d')}to{next_month.strftime('%Y%m%d')}"

# 3. Handle File Pathing & Loading based on format
if self._is_nc_dataset:
file_name = f"{self._file_name_stem}{date_str}_{self._mfm.mfm_name}.nc"
if self._is_monthly_dataset:
file_name = f"{self._file_name_stem}{date_str}_{self._mfm.mfm_name}.{self._preferred_ext}"
full_file_path = self._file_path_stem / file_name
file_content = self._get_cached_datasets_netcdf(full_file_path)
if not full_file_path.exists():
logger.warning(f"File not found: {full_file_path}")
file_content = {}
elif self._preferred_ext == "nc":
file_content = self._get_cached_datasets_netcdf(full_file_path)
elif self._preferred_ext == "h5":
file_content = self._get_cached_datasets_h5(full_file_path)
elif self._preferred_ext == "cdf":
file_content = self._get_cached_datasets_cdf(full_file_path)
else:
if self._verbose:
logger.info(f"Loading {full_file_path}")
file_content = load_file_any_format(full_file_path)
else:
file_name_no_format = f"{self._file_name_stem}{date_str}_{var.mat_file_prefix}"
if var.mat_has_B:
file_name_no_format += f"_n4_4_{self._mfm.mfm_name}"
file_name_no_format += "_ver4"

full_file_path = get_file_path_any_format(
self._file_path_stem, file_name_no_format, self._preferred_ext, self._is_nc_dataset
self._file_path_stem, file_name_no_format, self._preferred_ext, self._is_monthly_dataset
)
if full_file_path is None:
print(f"File not found: {file_name_no_format}")
logger.warning(f"File not found: {file_name_no_format}")
continue

if self._verbose:
print(f"\tLoading {full_file_path}")
logger.info(f"Loading {full_file_path}")
file_content = load_file_any_format(full_file_path)

if not file_content:
continue

# 4. Process Datetimes
raw_times = file_content["time"]
if self._is_nc_dataset:
# NetCDF timestamp logic
if self._is_monthly_dataset and self._preferred_ext in ["nc", "h5", "cdf", "mat"]:
# Monthly-file (NetCDF/HDF5/CDF/.mat) timestamp logic
if self._preferred_ext == "mat":
logger.info(
"Assuming time variable in .mat files is in POSIX timestamp format (seconds since 1970-01-01T00:00:00Z)"
)
datetimes = np.asarray(
[dt.datetime.fromtimestamp(t.astype(np.int64), tz=dt.timezone.utc) for t in raw_times]
)
else:
# Matlab logic
# Matlab/pickle logic
datetimes = np.asarray([matlab2python(t) for t in raw_times])

file_content["datetime"] = datetimes
@@ -486,8 +532,18 @@ def _load_variable(self, var: Variable | VariableEnum) -> None:
for var_name in var_names_stored:
val = list(loaded_var_arrs[var_name]) if var_name == "datetime" else loaded_var_arrs[var_name]

if self._is_nc_dataset:
# NetCDF name mapping logic
if self._is_monthly_dataset and self._preferred_ext in ["nc", "h5", "cdf", "mat"]:
if var_name.startswith("custom/"):
custom_key = var_name.split("/", 1)[1]
if custom_key:
custom_dict = self.__dict__.get("custom")
if not isinstance(custom_dict, dict):
custom_dict = {}
custom_dict[custom_key] = val
setattr(self, "custom", custom_dict)
continue

# NetCDF/HDF5/CDF name mapping logic
rbm_names = self._get_rbm_name_for_nc(var_name, self._mfm.mfm_name) # type: ignore
if rbm_names:
for name in rbm_names if isinstance(rbm_names, list) else [rbm_names]:
@@ -498,12 +554,32 @@ def _load_variable(self, var: Variable | VariableEnum) -> None:
def _get_cached_datasets_netcdf(self, file_path: Path) -> dict[str, Any]:
"""Return cached parsed NetCDF content for a monthly file."""
file_path = Path(file_path)
if file_path not in self._netcdf_dataset_cache:
if file_path not in self._monthly_dataset_cache:
if self._verbose:
logger.info(f"Loading netCDF {file_path}")

self._monthly_dataset_cache[file_path] = read_all_datasets_netcdf(file_path)
return self._monthly_dataset_cache[file_path]

def _get_cached_datasets_h5(self, file_path: Path) -> dict[str, Any]:
"""Return cached parsed HDF5 content for a monthly file."""
file_path = Path(file_path)
if file_path not in self._monthly_dataset_cache:
if self._verbose:
logger.info(f"Loading H5 {file_path}")

self._monthly_dataset_cache[file_path] = read_all_datasets_h5(file_path)
return self._monthly_dataset_cache[file_path]

def _get_cached_datasets_cdf(self, file_path: Path) -> dict[str, Any]:
"""Return cached parsed CDF content for a monthly file."""
file_path = Path(file_path)
if file_path not in self._monthly_dataset_cache:
if self._verbose:
print(f"\tLoading {file_path}")
logger.info(f"Loading CDF {file_path}")

self._netcdf_dataset_cache[file_path] = read_all_datasets_netcdf(file_path)
return self._netcdf_dataset_cache[file_path]
self._monthly_dataset_cache[file_path] = read_all_datasets_cdf(file_path)
return self._monthly_dataset_cache[file_path]
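
The three _get_cached_datasets_* helpers differ only in the reader they call and share the same Path-keyed _monthly_dataset_cache. A generic form is sketched below for clarity; the PR itself keeps the three explicit methods, and the reader names are the ones imported at the top of this module:

```python
from pathlib import Path
from typing import Any, Callable

def get_cached(
    cache: dict[Path, dict[str, Any]],
    readers: dict[str, Callable[[Path], dict[str, Any]]],
    ext: str,
    file_path: Path,
) -> dict[str, Any]:
    # Parse each monthly file at most once; later calls hit the cache.
    file_path = Path(file_path)
    if file_path not in cache:
        cache[file_path] = readers[ext](file_path)
    return cache[file_path]
```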

@classmethod
def _get_rbm_name_for_nc(
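
Pulling the constructor changes together, a hedged usage sketch: the positional identifiers, folder path, and variable name below are placeholders (the real enums and folder layout are defined elsewhere in swvo); only the keyword arguments and the validation rules come from this diff:

```python
import datetime as dt

from swvo.io.RBMDataSet import RBMDataSet

# Positional order assumed from the class repr (satellite, instrument, mfm);
# the identifier strings are placeholders.
ds = RBMDataSet(
    "rbspa",
    "mageis",
    "T89",  # strings are coerced via MfmEnum[mfm.upper()]
    start_time=dt.datetime(2017, 1, 1, tzinfo=dt.timezone.utc),
    end_time=dt.datetime(2017, 2, 1, tzinfo=dt.timezone.utc),
    folder_path="/data/rbm",    # placeholder path
    preferred_extension="h5",   # "cdf" and "h5" are new in this PR
)

ds.load("Flux")   # placeholder variable name; triggers monthly-file loading
print(ds.custom)  # variables stored as "custom/<name>" land in this dict

# Combinations rejected by the new validation:
#   dataorg=True  with preferred_extension in ("nc", "cdf", "h5") -> ValueError
#   dataorg=False with preferred_extension="pickle"               -> ValueError
# and "pickle" now emits a FutureWarning (deprecated format).
```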
4 changes: 2 additions & 2 deletions swvo/io/RBMDataSet/interp_functions.py
@@ -135,7 +135,7 @@ def interp_flux(
self.Flux,
self.energy_channels,
self.alpha_eq_model,
targets,
targets, # ty:ignore[invalid-argument-type]
)

with Pool(n_threads) as p:
@@ -308,7 +308,7 @@ def interp_psd(self: RBMDataSet,
_ = self.PSD; _ = self.InvMu; _ = self.InvK

# parallel over time (same pattern as interp_flux)
func = partial(_interp_psd_parallel, self.PSD, self.InvMu, self.InvK, targets)
func = partial(_interp_psd_parallel, self.PSD, self.InvMu, self.InvK, targets) # ty:ignore[invalid-argument-type]

with Pool(n_threads) as p:
rs = p.map_async(func, range(len(self.time)))
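
The interp_functions changes are purely ty:ignore annotations on the partial() calls; for readers unfamiliar with the pattern those lines use, here is a minimal self-contained sketch of binding shared arrays with partial and fanning out over time indices with Pool (the worker and data are hypothetical stand-ins, not swvo functions):

```python
from functools import partial
from multiprocessing import Pool

import numpy as np

def _work(shared: np.ndarray, t_idx: int) -> float:
    # One time step's computation against the bound shared array.
    return float(shared[t_idx].sum())

if __name__ == "__main__":
    data = np.arange(12.0).reshape(4, 3)  # (time, channel), made-up values
    func = partial(_work, data)           # bind the shared array once
    with Pool(2) as p:
        rs = p.map_async(func, range(data.shape[0]))
        results = rs.get()
    print(results)  # [3.0, 12.0, 21.0, 30.0]
```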
14 changes: 7 additions & 7 deletions swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py
@@ -135,7 +135,7 @@ def create_RBSP_line_data(

for i, instrument in enumerate(instruments):
energy_offsets[i] = np.nanmin(
np.abs(rbm_data[i].energy_channels_no_time - target_en_single),
np.abs(rbm_data[i].energy_channels_no_time - target_en_single), # ty:ignore[unsupported-operator]
axis=None,
)

@@ -163,7 +163,7 @@ def create_RBSP_line_data(
rbm_data_set_result.line_data_energy = np.empty((len(target_en),)) # ty:ignore[invalid-argument-type, unresolved-attribute]
rbm_data_set_result.line_data_alpha_local = np.empty((len(target_al),)) # ty:ignore[invalid-argument-type, unresolved-attribute]

energy_offsets_relative = energy_offsets / target_en_single
energy_offsets_relative = energy_offsets / target_en_single # ty:ignore[unsupported-operator]
if np.all(np.abs(energy_offsets_relative) > energy_offset_threshold):
raise ValueError(
f"For the given energy target ({target_en_single:.2e} MeV), no suitable energy channel was found for a threshold of {energy_offset_threshold:.02f}!"
@@ -178,7 +178,7 @@
)

closest_en_idx = np.nanargmin(
np.abs(rbm_data[min_offset_instrument].energy_channels_no_time - target_en_single)
np.abs(rbm_data[min_offset_instrument].energy_channels_no_time - target_en_single) # ty:ignore[unsupported-operator]
)
rbm_data_set_result.line_data_energy[e] = rbm_data[min_offset_instrument].energy_channels_no_time[
closest_en_idx
@@ -199,7 +199,7 @@
else:
rbm_data_set_result.line_data_flux[:, e] = np.squeeze(
rbm_data[min_offset_instrument].interp_flux(
target_en_single,
target_en_single, # ty:ignore[invalid-argument-type]
target_al[e], # ty:ignore[not-subscriptable]
TargetType.TargetPairs,
)
@@ -208,7 +208,7 @@
elif target_type == TargetType.TargetMeshGrid:
for a, target_al_single in enumerate(target_al):
closest_al_idx = np.nanargmin(
np.abs(rbm_data[min_offset_instrument].alpha_local_no_time - target_al_single)
np.abs(rbm_data[min_offset_instrument].alpha_local_no_time - target_al_single) # ty:ignore[unsupported-operator]
)
rbm_data_set_result.line_data_alpha_local[a] = rbm_data[min_offset_instrument].alpha_local_no_time[
closest_al_idx
@@ -221,8 +221,8 @@
else:
rbm_data_set_result.line_data_flux[:, e, a] = np.squeeze(
rbm_data[min_offset_instrument].interp_flux(
target_en_single,
target_al_single,
target_en_single, # ty:ignore[invalid-argument-type]
target_al_single, # ty:ignore[invalid-argument-type]
TargetType.TargetPairs,
)
)
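
The script changes are likewise ty:ignore annotations around the nearest-channel arithmetic. A self-contained sketch of that selection logic follows, with made-up channel values and threshold (the real script additionally compares offsets across instruments before picking one):

```python
import numpy as np

energy_channels = np.array([0.5, 0.9, 1.8, 3.4])  # MeV, made-up values
target = 1.0      # MeV, made-up target
threshold = 0.3   # relative-offset threshold, made up

offsets = np.abs(energy_channels - target)
relative = offsets / target
if np.all(relative > threshold):
    raise ValueError(f"No channel within {threshold:.0%} of {target} MeV")

closest = int(np.nanargmin(offsets))
print(energy_channels[closest])  # -> 0.9, the nearest usable channel
```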