diff --git a/src/quantem/core/datastructures/vector.py b/src/quantem/core/datastructures/vector.py index 9bc513a4..aa2627d2 100644 --- a/src/quantem/core/datastructures/vector.py +++ b/src/quantem/core/datastructures/vector.py @@ -1,1026 +1,1310 @@ -from typing import ( - Any, - List, - Optional, - Tuple, - Union, - cast, - overload, -) +from __future__ import annotations + +import copy +from pathlib import Path +from typing import Any, Literal, Sequence import numpy as np -from numpy.typing import ArrayLike, NDArray +from numpy.typing import NDArray from quantem.core.io.serialize import AutoSerialize from quantem.core.utils.validators import ( validate_fields, validate_num_fields, validate_shape, - validate_vector_data, - validate_vector_data_for_inference, validate_vector_units, ) class Vector(AutoSerialize): - """ - A class for holding vector data with ragged array lengths. This class supports any number of fixed dimensions - (indexed first) followed by a ragged numpy array that can have any number of entries (rows) and columns (fields). - Inherits from AutoSerialize for serialization support. - - Basic Usage: - ----------- - # Create a 2D vector with shape=(4, 3) and 3 named fields - v = Vector.from_shape(shape=(4, 3), fields=['field0', 'field1', 'field2']) - - # Alternative creation with num_fields instead of fields - v = Vector.from_shape(shape=(4, 3), num_fields=3) # Fields will be named field_0, field_1, field_2 - - # Create with custom name and units - v = Vector.from_shape( - shape=(4, 3), - fields=['field0', 'field1', 'field2'], - name='my_vector', - units=['unit0', 'unit1', 'unit2'], - ) - - # Access data at specific indices - data = v[0, 1] # Returns numpy array at position (0,1) - - # Set data at specific indices - v[0, 1] = np.array([[1.0, 2.0, 3.0]]) # Must match num_fields - - # Create a deep copy - v_copy = v.copy() - - Example usage of from_data: - ----------------------------------- - data = [ - np.array([[1, 2], [3, 4]]), - np.array([[5, 6], [7, 8], [9, 10]]) - ] - v = Vector.from_data( - data, - fields=['x', 'y'], - name='my_ragged_vector', - units=['m', 'm'] - ) - - # Or using lists instead of numpy arrays: - data = [ - [[1, 2], [3, 4]], - [[5, 6], [7, 8], [9, 10]], - ] - v = Vector.from_data( - data, - fields=['x', 'y'], - name='my_ragged_vector', - units=['m', 'm'] - ) - - Field Operations: - ---------------- - # Access a specific field - field_data = v['field0'] # Returns a FieldView object - - # Perform operations on a field - v['field0'] += 16 # Add 16 to all field0 values - - # Apply a function to a field - v['field2'] = lambda x: x * 2 # Double all field2 values - - # Get flattened field data - field_flat = v['field0'].flatten() # Returns 1D numpy array - - # Set field data from flattened array - v['field2'].set_flattened(new_values) # Must match total length - - Advanced Operations: - ------------------- - # Complex field calculations - scale = v['field0'].flatten() / (v['field0'].flatten()**2 + v['field1'].flatten()**2) - v['field2'].set_flattened(v['field2'].flatten() * scale) - - # Slicing and assignment - v[2:4, 1] = v[1:3, 1] # Copy data from one region to another - - # Boolean indexing - mask = v['field0'].flatten() > 0 - v['field2'].set_flattened(v['field2'].flatten() * mask) - - # Field management - v.add_fields(('field3', 'field4', 'field5')) # Add new fields - v.remove_fields(('field3', 'field4', 'field5')) # Remove fields - - Direct Data Access: - ------------------ - # Get data with integer indexing - data = v.get_data(0, 1) # Returns numpy array at (0,1) - - # Get data with slice indexing - data = v.get_data(slice(0, 2), 1) # Returns list of arrays for rows 0-1 at column 1 - - # Set data with integer indexing - v.set_data(np.array([[1.0, 2.0, 3.0]]), 0, 1) # Set data at (0,1) - - # Set data with slice indexing - v.set_data([np.array([[1.0, 2.0, 3.0]]), np.array([[4.0, 5.0, 6.0]])], - slice(0, 2), 1) # Set data for rows 0-1 at column 1 - - Notes: + """Ragged cell data on a fixed grid. + + A ``Vector`` has two independent axes of structure: + - fixed-grid dimensions given by ``shape`` + - ragged rows stored inside each fixed-grid cell + + Each ragged row has one value per named field, so each cell behaves like a + small 2D array with shape ``(n_rows, num_fields)``, where ``n_rows`` may + vary from cell to cell. + + Parameters + ---------- + shape : tuple of int + Fixed-grid shape. + fields : sequence of str + Field names in column order. + units : sequence of str, optional + Units corresponding to ``fields``. If omitted, units default to + ``"none"`` for all fields. + name : str, optional + Descriptive name for the Vector. + metadata : dict, optional + Additional user metadata. + + Notes ----- - - All numpy arrays stored in the vector must have the same number of columns (fields) - - Field names must be unique - - Slicing operations return new Vector instances - - Field operations are performed in-place - - Units are stored for each field and can be accessed via the units attribute - - The name attribute can be used to identify the vector in a larger context + The public API keeps fixed-grid indexing and field selection separate: + - use ``[]`` for fixed-grid indexing + - use ``select_fields(...)`` for field selection + + Fixed-grid indexing always returns a ``Vector``. A 0D selection exposes its + underlying cell array through ``.array``. Multi-cell selections can be + concatenated with ``flatten()``. + + The internal representation is compact: + - ``_state["data"]`` stores all ragged rows in one numeric 2D array + - ``_state["cell_starts"]`` stores the start offset for each cell + - ``_state["cell_lengths"]`` stores the row count for each cell + + A ``Vector`` selection is a write-through view over shared storage. Views + track only the selected fixed-grid shape, selected cell indices, and selected + field names. + + Examples + -------- + Create a Vector and assign one cell: + + >>> import numpy as np + >>> v = Vector.from_shape((2, 2), fields=("kx", "ky", "intensity")) + >>> v[0, 0] = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + >>> v[0, 0].array.shape + (2, 3) + + Select fields and apply in-place arithmetic: + + >>> kx = v.select_fields("kx") + >>> kx += 16 + >>> kx.flatten().shape + (2, 1) + + Apply a rowwise transform with ``flatten()`` and ``set_flattened()``: + + >>> kx = v.select_fields("kx") + >>> ky = v.select_fields("ky") + >>> kx.set_flattened( + ... np.where( + ... ((kx.flatten() - 16) ** 2 + (ky.flatten() - 16) ** 2) < 12, + ... 10, + ... kx.flatten(), + ... ) + ... ) """ + __array_priority__ = 1000 _token = object() + # ------------------------------------------------------------------ # + # Construction + # ------------------------------------------------------------------ # + def __init__( self, - shape: Tuple[int, ...], - fields: List[str], - units: List[str], - name: str, - metadata: dict = {}, + shape: tuple[int, ...], + fields: Sequence[str], + units: Sequence[str] | None = None, + name: str | None = None, + metadata: dict[str, Any] | None = None, _token: object | None = None, ) -> None: if _token is not self._token: - raise RuntimeError("Use Vector.from_shape() or Vector.from_data() to instantiate.") + raise RuntimeError( + "Use Vector.from_shape() or Vector.from_data() to instantiate this class." + ) + root_shape = validate_shape(shape) + root_fields = validate_fields(list(fields)) + root_units = validate_vector_units( + list(units) if units is not None else None, + len(root_fields), + ) - self.shape = shape - self.fields = fields - self.units = units - self.name = name - self._data = nested_list(self.shape, fill=None) - self._metadata = metadata + self._state = { + "shape": root_shape, + "fields": list(root_fields), + "units": list(root_units), + "name": name or f"{len(root_shape)}d ragged array", + "metadata": dict(metadata or {}), + "data": np.empty((0, len(root_fields)), dtype=float), + "cell_starts": np.zeros(_cell_count(root_shape), dtype=np.int64), + "cell_lengths": np.zeros(_cell_count(root_shape), dtype=np.int64), + } + self._selection_shape = root_shape + self._selection_indices: NDArray[np.int64] | None = None + self._selected_fields: tuple[str, ...] | None = None @classmethod - def from_shape( + def _from_view( cls, - shape: Tuple[int, ...], - num_fields: Optional[int] = None, - fields: Optional[List[str]] = None, - units: Optional[List[str]] = None, - name: Optional[str] = None, + state: dict[str, Any], + selection_shape: tuple[int, ...], + selection_indices: NDArray[np.int64] | None, + selected_fields: tuple[str, ...] | None, ) -> "Vector": - """ - Factory method to create a Vector with the specified shape and fields. - - Parameters - ---------- - shape : Tuple[int, ...] - The shape of the vector (dimensions) - num_fields : Optional[int] - Number of fields in the vector - name : Optional[str] - Name of the vector - fields : Optional[List[str]] - List of field names - units : Optional[List[str]] - List of units for each field - - Returns - ------- - Vector - A new Vector instance - """ - validated_shape = validate_shape(shape) - ndim = len(validated_shape) - - if fields is not None: - validated_fields = validate_fields(fields) - validated_num_fields = len(validated_fields) - if num_fields is not None and validated_num_fields != num_fields: - raise ValueError( - f"num_fields ({num_fields}) does not match length of fields ({validated_num_fields})" - ) - elif num_fields is not None: - validated_num_fields = validate_num_fields(num_fields) - validated_fields = [f"field_{i}" for i in range(validated_num_fields)] - else: - raise ValueError("Must specify either 'fields' or 'num_fields'.") - - validated_units = validate_vector_units(units, validated_num_fields) - name = name or f"{ndim}d ragged array" + """Build a view that shares backing storage with another Vector.""" + obj = cls.__new__(cls) + obj._state = state + obj._selection_shape = selection_shape + obj._selection_indices = ( + None if selection_indices is None else selection_indices.astype(np.int64, copy=False) + ) + obj._selected_fields = selected_fields + return obj + @classmethod + def from_shape( + cls, + shape: tuple[int, ...], + num_fields: int | None = None, + fields: Sequence[str] | None = None, + units: Sequence[str] | None = None, + name: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> "Vector": + """Create an empty Vector with the given fixed-grid shape and fields.""" + fields = _resolve_fields(fields, num_fields, None) return cls( - shape=validated_shape, - fields=validated_fields, - units=validated_units, + shape=shape, + fields=fields, + units=units, name=name, + metadata=metadata, _token=cls._token, ) @classmethod def from_data( cls, - data: List[Any], - num_fields: Optional[int] = None, - fields: Optional[List[str]] = None, - units: Optional[List[str]] = None, - name: Optional[str] = None, + data: Sequence[Any], + num_fields: int | None = None, + fields: Sequence[str] | None = None, + units: Sequence[str] | None = None, + name: str | None = None, + metadata: dict[str, Any] | None = None, ) -> "Vector": - """ - Factory method to create a Vector from a list of - ragged lists or ragged numpy arrays. + """Create a Vector from nested fixed-grid data. - Parameters - ---------- - data : List[Any] - A list of ragged lists containing the vector data. - Each element should be a numpy array with shape (n, num_fields). - num_fields : Optional[int] - Number of fields in the vector. If not provided, it will be inferred from the data. - fields : Optional[List[str]] - List of field names - units : Optional[List[str]] - List of units for each field - name : Optional[str] - Name of the vector - - Returns - ------- - Vector - A new Vector instance with the provided data - - Raises - ------ - ValueError - If the data structure is invalid or inconsistent - TypeError - If the data contains invalid types + The outer nesting defines the fixed-grid shape. Each leaf must coerce to a + 2D cell array with consistent field count across all cells. """ - inferred_shape, inferred_num_fields = validate_vector_data_for_inference(data) - - final_num_fields = num_fields or inferred_num_fields - if num_fields is not None and num_fields != inferred_num_fields: - raise ValueError( - f"Provided num_fields ({num_fields}) does not match inferred ({inferred_num_fields})." - ) - - vector = cls.from_shape( - shape=inferred_shape, - num_fields=final_num_fields, - fields=fields, + if not isinstance(data, (list, tuple)): + raise TypeError(f"Data must be a list or tuple, got {type(data)}") + root_shape, cell_arrays = _flatten_fixed_grid(data) if len(data) > 0 else ((0,), []) + inferred_counts = {array.shape[1] for array in cell_arrays} + if len(inferred_counts) > 1: + raise ValueError("All cell arrays must have the same number of fields.") + inferred_fields = cell_arrays[0].shape[1] if cell_arrays else 0 + + vector = cls( + shape=root_shape, + fields=_resolve_fields(fields, num_fields, inferred_fields), units=units, name=name, + metadata=metadata, + _token=cls._token, ) - - # Now fully validate and set the data - vector.data = data + vector._replace_cells(np.arange(len(cell_arrays), dtype=np.int64), cell_arrays) return vector - def get_data( - self, *indices: Union[int, slice, List[int], np.ndarray[Any, np.dtype[Any]]] - ) -> Union[NDArray, List[NDArray]]: - """ - Get data at specified indices. - - Parameters: - ----------- - *indices : Union[int, slice, List[int], np.ndarray] - Indices to access. Must match the number of dimensions in the vector. - Supports fancy indexing with lists or numpy arrays. - - Returns: - -------- - numpy.ndarray or list - The data at the specified indices. - - Raises: - ------- - IndexError - If indices are out of bounds. - ValueError - If the number of indices does not match the vector dimensions. - """ - if len(indices) != len(self._shape): - raise ValueError(f"Expected {len(self._shape)} indices, got {len(indices)}") - - # Handle fancy indexing and slicing - def get_indices(dim_idx: Any, dim_size: int) -> np.ndarray: - if isinstance(dim_idx, slice): - start, stop, step = dim_idx.indices(dim_size) - return np.arange(start, stop, step) - elif isinstance(dim_idx, (np.ndarray, list)): - idx = np.asarray(dim_idx) - if np.any((idx < 0) | (idx >= dim_size)): - raise IndexError(f"Index out of bounds for axis with size {dim_size}") - return idx - elif isinstance(dim_idx, (int, np.integer)): - if dim_idx < 0 or dim_idx >= dim_size: - raise IndexError( - f"Index {dim_idx} out of bounds for axis with size {dim_size}" - ) - return np.array([dim_idx]) - return np.arange(dim_size) - - # Get indices for each dimension - indices_arrays = [get_indices(i, s) for i, s in zip(indices, self._shape)] - - # If all indices are single integers, return a single array - if all(len(i) == 1 for i in indices_arrays): - ref = self._data - for idx in (i[0] for i in indices_arrays): - ref = ref[idx] - return ref - - # Create result structure for fancy indexing - result = [] - for idx in np.ndindex(*[len(i) for i in indices_arrays]): - src_idx = tuple(ind[i] for ind, i in zip(indices_arrays, idx)) - result.append(self._data[src_idx[0]][src_idx[1]]) + # ------------------------------------------------------------------ # + # Identity properties + # ------------------------------------------------------------------ # - return result + @property + def name(self) -> str: + """Human-readable Vector name.""" + return self._state["name"] - def set_data( - self, - value: Union[NDArray, List[NDArray]], - *indices: Union[int, slice, List[int], np.ndarray[Any, np.dtype[Any]]], - ) -> None: - """ - Set data at specified indices. + @name.setter + def name(self, value: str) -> None: + self._state["name"] = str(value) - Parameters - ---------- - value : Union[NDArray, List[NDArray]] - The numpy array(s) to set at the specified indices. Must have shape (_, num_fields). - For fancy indexing, can be a list of arrays. - *indices : Union[int, slice, List[int], np.ndarray] - Indices to set data at. Must match the number of dimensions in the vector. - Supports fancy indexing with lists or numpy arrays. - - Raises - ------ - IndexError - If indices are out of bounds. - ValueError - If the number of indices does not match the vector dimensions, - or if the value shape doesn't match the expected shape. - TypeError - If the value is not a numpy array or list of numpy arrays. - """ - if len(indices) != len(self._shape): - raise ValueError(f"Expected {len(self._shape)} indices, got {len(indices)}") - - # Handle fancy indexing and slicing - def get_indices(dim_idx: Any, dim_size: int) -> np.ndarray: - if isinstance(dim_idx, slice): - start, stop, step = dim_idx.indices(dim_size) - return np.arange(start, stop, step) - elif isinstance(dim_idx, (np.ndarray, list)): - idx = np.asarray(dim_idx) - if np.any((idx < 0) | (idx >= dim_size)): - raise IndexError(f"Index out of bounds for axis with size {dim_size}") - return idx - elif isinstance(dim_idx, (int, np.integer)): - if dim_idx < 0 or dim_idx >= dim_size: - raise IndexError( - f"Index {dim_idx} out of bounds for axis with size {dim_size}" - ) - return np.array([dim_idx]) - return np.arange(dim_size) - - # Get indices for each dimension - indices_arrays = [get_indices(i, s) for i, s in zip(indices, self._shape)] - - # If all indices are single integers, handle as single value - if all(len(i) == 1 for i in indices_arrays): - if not isinstance(value, np.ndarray): - raise TypeError(f"Value must be a numpy array, got {type(value).__name__}") - if value.ndim != 2 or value.shape[1] != self.num_fields: - raise ValueError( - f"Expected a numpy array with shape (_, {self.num_fields}), got {value.shape}" - ) - ref = self._data - for idx in (i[0] for i in indices_arrays[:-1]): - ref = ref[idx] - ref[indices_arrays[-1][0]] = value - return + @property + def metadata(self) -> dict[str, Any]: + """Mutable metadata dictionary shared by all views.""" + return self._state["metadata"] - # Handle fancy indexing - if not isinstance(value, list): - raise TypeError("For fancy indexing, value must be a list of numpy arrays") - - # Validate and set values - for idx in np.ndindex(*[len(i) for i in indices_arrays]): - src_idx = tuple(ind[i] for ind, i in zip(indices_arrays, idx)) - if not isinstance(value[idx[0]], np.ndarray): - raise TypeError(f"Expected numpy array, got {type(value[idx[0]]).__name__}") - if value[idx[0]].ndim != 2 or value[idx[0]].shape[1] != self.num_fields: - raise ValueError( - f"Expected array with shape (_, {self.num_fields}), got {value[idx[0]].shape}" - ) - ref = self._data - for i in src_idx[:-1]: - ref = ref[i] - ref[src_idx[-1]] = value[idx[0]] - - @overload - def __getitem__(self, idx: str) -> "_FieldView": ... - @overload - def __getitem__( - self, - idx: Union[Tuple[Union[int, slice, List[int]], ...], int, slice, List[int]], - ) -> Union[NDArray, "Vector"]: ... + # ------------------------------------------------------------------ # + # Shape & structure properties + # ------------------------------------------------------------------ # - def __getitem__( - self, - idx: Union[str, Tuple[Union[int, slice, List[int]], ...], int, slice, List[int]], - ) -> Union["_FieldView", NDArray, "Vector"]: - """Get data or a view of the vector at specified indices.""" - if isinstance(idx, str): - if idx not in self._fields: - raise KeyError(f"Field '{idx}' not found.") - return _FieldView(self, idx) - - # Normalize idx to tuple - normalized: Tuple[Any, ...] = (idx,) if not isinstance(idx, tuple) else idx - - # Convert lists/arrays to ndarray - idx_converted: Tuple[Union[int, slice, np.ndarray[Any, np.dtype[Any]]], ...] = tuple( - np.asarray(i) if isinstance(i, (list, np.ndarray)) else i for i in normalized - ) + @property + def shape(self) -> tuple[int, ...]: + """Return the fixed-grid shape of this selection.""" + return self._selection_shape - # Check if we should return a numpy array (all indices are integers) - return_np = all(isinstance(i, (int, np.integer)) for i in idx_converted[: len(self.shape)]) - if len(idx_converted) < len(self.shape): - return_np = False - - if return_np: - view = self._data - for i in idx_converted: - view = view[i] - return cast(NDArray[Any], view) - - # Handle fancy indexing and slicing - def get_indices(dim_idx: Any, dim_size: int) -> np.ndarray: - if isinstance(dim_idx, slice): - start, stop, step = dim_idx.indices(dim_size) - return np.arange(start, stop, step) - elif isinstance(dim_idx, (np.ndarray, list)): - return np.asarray(dim_idx) - elif isinstance(dim_idx, (int, np.integer)): - return np.array([dim_idx]) - return np.arange(dim_size) - - # Get indices for each dimension - full_idx = list(idx_converted) + [slice(None)] * (len(self.shape) - len(idx_converted)) - indices = [get_indices(i, s) for i, s in zip(full_idx, self.shape)] - - # Create new shape and data - new_shape = [len(i) for i in indices] - new_data = [[None] * new_shape[-1] for _ in range(new_shape[0])] - - # Fill the new data structure - for out_idx in np.ndindex(*new_shape): - src_idx = tuple(ind[i] for ind, i in zip(indices, out_idx)) - new_data[out_idx[0]][out_idx[1]] = self._data[src_idx[0]][src_idx[1]] - - # Create new Vector - vector_new = Vector.from_shape( - shape=tuple(new_shape), - num_fields=self.num_fields, - name=self.name + "[view]", - fields=self.fields, - units=self.units, - ) - vector_new._data = new_data - return vector_new + @property + def fields(self) -> list[str]: + """Return selected field names in column order.""" + if self._selected_fields is None: + return list(self._state["fields"]) + return list(self._selected_fields) - def __setitem__( - self, - idx: Union[Tuple[Union[int, slice, List[int]], ...], int, slice, List[int], str], - value: Union[NDArray, List[NDArray]], - ) -> None: - """Set data at specified indices.""" - if isinstance(idx, str): - if idx not in self._fields: - raise KeyError(f"Field '{idx}' not found.") - field_view = _FieldView(self, idx) - field_view.set_flattened(value) - return + @property + def units(self) -> list[str]: + """Return units for the selected fields.""" + lookup = dict(zip(self._state["fields"], self._state["units"])) + return [lookup[field] for field in self.fields] - # Normalize idx to tuple - normalized: Tuple[Any, ...] = (idx,) if not isinstance(idx, tuple) else idx + @property + def num_fields(self) -> int: + """Return the number of selected fields.""" + return len(self.fields) - # Convert lists/arrays to ndarray - idx_converted: Tuple[Union[int, slice, np.ndarray[Any, np.dtype[Any]]], ...] = tuple( - np.asarray(i) if isinstance(i, (list, np.ndarray)) else i for i in normalized - ) + @property + def num_cells(self) -> int: + """Return the number of fixed-grid cells in the current selection.""" + return int(self._selected_cell_indices().size) - # Check if we're doing slice‐ or array‐based (multi‐cell) indexing - has_fancy = any( - isinstance(i, slice) or (isinstance(i, np.ndarray) and i.size > 1) - for i in idx_converted[: len(self.shape)] - ) + @property + def total_rows(self) -> int: + """Return the total ragged-row count in the current selection.""" + return int(self._state["cell_lengths"][self._selected_cell_indices()].sum()) - if has_fancy: - # If user passed a Vector, extract its cell arrays - if isinstance(value, Vector): - - def _flatten_cells(data): - if isinstance(data, np.ndarray): - return [data] - out = [] - for sub in data: - out.extend(_flatten_cells(sub)) - return out - - value = _flatten_cells(value._data) - - # For fancy indexing, value should be a list of arrays - if not isinstance(value, list): - raise TypeError( - "For fancy/slice indexing, value must be a list of numpy arrays or a Vector" - ) - - # Get indices for each dimension - def get_indices(dim_idx: Any, dim_size: int) -> np.ndarray: - if isinstance(dim_idx, slice): - start, stop, step = dim_idx.indices(dim_size) - return np.arange(start, stop, step) - elif isinstance(dim_idx, (np.ndarray, list)): - idx = np.asarray(dim_idx) - if np.any((idx < 0) | (idx >= dim_size)): - raise IndexError(f"Index out of bounds for axis with size {dim_size}") - return idx - elif isinstance(dim_idx, (int, np.integer)): - if dim_idx < 0 or dim_idx >= dim_size: - raise IndexError(f"Index out of bounds for axis with size {dim_size}") - return np.array([dim_idx]) - return np.arange(dim_size) - - indices_arrays = [get_indices(i, s) for i, s in zip(idx_converted, self._shape)] - total_indices = np.prod([len(i) for i in indices_arrays]) - - if len(value) != total_indices: - raise ValueError(f"Expected {total_indices} arrays, got {len(value)}") - - # Validate and set values - for array_idx, idx in enumerate(np.ndindex(*[len(i) for i in indices_arrays])): - src_idx = tuple(ind[i] for ind, i in zip(indices_arrays, idx)) - if not isinstance(value[array_idx], np.ndarray): - raise TypeError(f"Expected numpy array, got {type(value[array_idx]).__name__}") - if value[array_idx].ndim != 2 or value[array_idx].shape[1] != self.num_fields: - raise ValueError( - f"Expected array with shape (_, {self.num_fields}), got {value[array_idx].shape}" - ) - ref = self._data - for i in src_idx[:-1]: - ref = ref[i] - ref[src_idx[-1]] = value[array_idx] - else: - # For single value assignment - if not isinstance(value, np.ndarray): - raise TypeError(f"Value must be a numpy array, got {type(value).__name__}") - if value.ndim != 2 or value.shape[1] != self.num_fields: - raise ValueError( - f"Expected a numpy array with shape (_, {self.num_fields}), got {value.shape}" - ) - ref = self._data - for i in idx_converted[:-1]: - ref = ref[i] - ref[idx_converted[-1]] = value - - def add_fields(self, new_fields: Union[str, List[str]]) -> None: + @property + def dtype(self) -> np.dtype[Any]: + """Return the NumPy dtype of the backing row buffer.""" + return self._state["data"].dtype + + # ------------------------------------------------------------------ # + # Data access + # ------------------------------------------------------------------ # + + @property + def array(self) -> NDArray[Any]: + """Return the selected cell as a NumPy array. + + This is only valid for 0D selections. Single-field and contiguous + multi-field selections return writable views into the backing storage. + Non-contiguous multi-field selections return a copy because NumPy cannot + expose a writable column-subset view for that layout. """ - Add new fields to the vector. + if self.shape != (): + raise ValueError(".array is only valid when the selection contains exactly one cell.") + cell = self._cell_matrix(self._selected_cell_indices()[0]) + cols = self._field_indices() + if cols.size == self._full_num_fields: + return cell + if cols.size == 1: + col = int(cols[0]) + return cell[:, col : col + 1] + if _is_contiguous(cols): + return cell[:, int(cols[0]) : int(cols[-1]) + 1] + return cell[:, cols].copy() + + def flatten(self) -> NDArray[Any]: + """Concatenate selected cells in row-major order. + + Returns a 2D array with shape ``(total_rows, num_fields)`` even for + single-field selections. + """ + arrays = [ + self._selected_cell_matrix(index) + for index in self._selected_cell_indices() + if self._cell_row_count(index) > 0 + ] + if arrays: + return np.vstack(arrays) - Parameters - ---------- - new_fields : Union[str, List[str]] - Field name(s) to add. Must be unique and not already present. + dtype = self._state["data"].dtype if self._state["data"].ndim == 2 else float + return np.empty((0, self.num_fields), dtype=dtype) + + def row_counts(self) -> list[int]: + """Return per-cell row counts in the current selection order.""" + return [self._cell_row_count(int(index)) for index in self._selected_cell_indices()] - Raises - ------ - ValueError - If any field name already exists or if there are duplicates + # ------------------------------------------------------------------ # + # Field management + # ------------------------------------------------------------------ # + + def select_fields(self, *field_names: str | Sequence[str]) -> "Vector": + """Return a view containing only the requested fields. + + Accepted forms: + - ``select_fields("kx")`` + - ``select_fields("kx", "ky")`` + - ``select_fields(["kx", "ky"])`` """ - if isinstance(new_fields, str): - new_fields = [new_fields] + if not field_names: + raise ValueError("At least one field name is required.") + if len(field_names) == 1 and not isinstance(field_names[0], str): + selected = _normalize_field_names(field_names[0]) + elif not all(isinstance(n, str) for n in field_names): + raise TypeError( + "select_fields(...) expects field names as strings or one sequence of strings." + ) else: - new_fields = list(new_fields) + selected = _normalize_field_names(field_names) # type: ignore[arg-type] + available = set(self.fields) + missing = [field for field in selected if field not in available] + if missing: + raise KeyError(f"Unknown field(s): {missing}") + + selected_fields = None if selected == tuple(self._state["fields"]) else selected + return self._from_view( + self._state, + self.shape, + self._selection_indices, + selected_fields, + ) - if any(name in self._fields for name in new_fields): + def add_fields( + self, + names: str | Sequence[str], + values: Any | None = None, + units: str | Sequence[str] | None = None, + ) -> None: + """Add one or more new fields to the full Vector schema.""" + self._require_full_field_view("add_fields") + new_fields = _normalize_field_names(names) + if any(field in self._state["fields"] for field in new_fields): raise ValueError("One or more new field names already exist.") - if len(set(new_fields)) != len(new_fields): - raise ValueError("Duplicate field names in input are not allowed.") + new_units = _normalize_units(units, len(new_fields)) + old_fields = list(self._state["fields"]) + self._state["fields"].extend(new_fields) + self._state["units"].extend(new_units) + self._expand_storage(len(new_fields)) - self._fields = list(self._fields) + list(new_fields) - self._units = list(self._units) + ["none"] * len(new_fields) + if values is None: + return - def expand_array(arr: Any) -> Any: - if isinstance(arr, np.ndarray): - if arr.shape[1] != self.num_fields - len(new_fields): - raise ValueError( - f"Expected arrays with {self.num_fields - len(new_fields)} fields, got {arr.shape[1]}" - ) - pad = np.zeros((arr.shape[0], len(new_fields))) - return np.hstack([arr, pad]) - elif isinstance(arr, list): - return [expand_array(sub) for sub in arr] - else: - return arr + target = self.select_fields(*new_fields) + if ( + len(new_fields) > 1 + and isinstance(values, (list, tuple)) + and len(values) == len(new_fields) + ): + for field, value in zip(new_fields, values): + target.select_fields(field)[...] = value + else: + target[...] = values - self._data = expand_array(self._data) + if self._selected_fields is not None and tuple(old_fields) == self._selected_fields: + self._selected_fields = None - def remove_fields(self, fields_to_remove: Union[str, List[str]]) -> None: - """ - Remove fields from the vector. + def rename_fields(self, mapping: dict[str, str]) -> None: + """Rename one or more fields in-place. Parameters ---------- - fields_to_remove : Union[str, List[str]] - Field name(s) to remove. Must exist in the vector. - - Raises - ------ - ValueError - If any field doesn't exist + mapping : dict + Maps each old field name to its new name, e.g. + ``{"kx": "qx", "ky": "qy"}``. """ - if isinstance(fields_to_remove, str): - fields_to_remove = [fields_to_remove] - else: - fields_to_remove = list(fields_to_remove) + old_field_set = set(self._state["fields"]) + missing = [old for old in mapping if old not in old_field_set] + if missing: + raise KeyError(f"Unknown field(s): {missing}") + new_names = list(mapping.values()) + conflicts = [n for n in new_names if n in old_field_set and n not in mapping] + if conflicts: + raise ValueError(f"New field name(s) already exist: {conflicts}") + validate_fields(new_names) + + rename = {old: new for old, new in mapping.items()} + self._state["fields"] = [rename.get(f, f) for f in self._state["fields"]] + if self._selected_fields is not None: + self._selected_fields = tuple(rename.get(f, f) for f in self._selected_fields) + + def remove_fields(self, names: str | Sequence[str]) -> None: + """Remove one or more fields from the full Vector schema.""" + self._require_full_field_view("remove_fields") + to_remove = set(_normalize_field_names(names)) + old_fields = self._state["fields"] + old_units = self._state["units"] + + missing = [field for field in to_remove if field not in old_fields] + if missing: + raise KeyError(f"Unknown field(s): {missing}") + if len(to_remove) == len(old_fields): + raise ValueError("Cannot remove all fields from a Vector.") + + keep = [i for i, field in enumerate(old_fields) if field not in to_remove] + self._state["fields"] = [old_fields[i] for i in keep] + self._state["units"] = [old_units[i] for i in keep] + self._state["data"] = self._state["data"][:, keep] + + if self._selected_fields is not None: + self._selected_fields = tuple( + field for field in self._selected_fields if field in self._state["fields"] + ) + if len(self._selected_fields) == len(self._state["fields"]): + self._selected_fields = None - field_to_index = {name: i for i, name in enumerate(self._fields)} - indices_to_remove = [] - for field in fields_to_remove: - if field not in field_to_index: - print(f"Warning: field '{field}' not found.") - else: - indices_to_remove.append(field_to_index[field]) + # ------------------------------------------------------------------ # + # Cell / row mutation + # ------------------------------------------------------------------ # - if not indices_to_remove: - return + def append_rows(self, idx: Any, rows: Any) -> None: + """Append one or more rows to a single selected cell. - indices_to_remove = sorted(set(indices_to_remove)) - keep_indices = [i for i in range(self.num_fields) if i not in indices_to_remove] + ``idx`` is interpreted with the same fixed-grid indexing rules as + ``__getitem__`` and must resolve to exactly one cell. Appending rows is a + full-cell operation, so all fields must be selected. + """ + target = self[idx] + if target.shape != (): + raise ValueError("append_rows requires an index that selects exactly one cell.") + target._require_full_field_view("append_rows") - # Update metadata - self._fields = [self._fields[i] for i in keep_indices] - self._units = [self._units[i] for i in keep_indices] + new_rows = _coerce_cell_array(rows, target.num_fields) + if new_rows.shape[0] == 0: + return - def prune_array(arr: Any) -> Any: - if isinstance(arr, np.ndarray): - if arr.shape[1] < max(indices_to_remove) + 1: - raise ValueError( - f"Cannot remove field index {max(indices_to_remove)} from array with shape {arr.shape}" - ) - return arr[:, keep_indices] - elif isinstance(arr, list): - return [prune_array(sub) for sub in arr] - else: - return arr + cell_index = int(target._selected_cell_indices()[0]) + existing = target._cell_matrix(cell_index) + combined = np.vstack((existing, new_rows)) if existing.shape[0] > 0 else new_rows.copy() + target._replace_cells(np.array([cell_index], dtype=np.int64), [combined]) - self._data = prune_array(self._data) + def set_flattened(self, values: Any) -> None: + """Write values back in flattened row-major order. - def copy(self) -> "Vector": + This updates existing rows without changing per-cell row counts. It is + the rowwise companion to ``flatten()`` and is especially useful for + NumPy-based transforms that operate on all selected rows at once. """ - Create a deep copy of the vector. + field_indices = self._field_indices() + targets = self._selected_cell_indices() + row_counts = self.row_counts() + total_rows = sum(row_counts) + + if isinstance(values, Vector): + if values.num_fields != self.num_fields: + raise ValueError(f"Expected {self.num_fields} fields, got {values.num_fields}") + flat_values = values.flatten() + if flat_values.shape[0] != total_rows: + raise ValueError(f"Expected {total_rows} rows, got {flat_values.shape[0]}") + else: + flat_values = _broadcast_field_values(values, total_rows, self.num_fields) - Returns - ------- - Vector - A new Vector instance with the same data, shape, fields, and units. + cursor = 0 + for target, rows in zip(targets, row_counts): + cell = self._cell_matrix(int(target)) + if rows > 0: + cell[:, field_indices] = flat_values[cursor : cursor + rows] + cursor += rows + + def compact(self) -> None: + """Repack the backing row buffer to remove dead rows. + + Whole-cell replacement appends new rows and leaves previous rows unused + until compaction. Calling ``compact()`` makes memory usage and save size + predictable at the cost of reallocating the backing buffer. """ - import copy + data = self._state["data"] + used_rows = int(self._state["cell_lengths"].sum()) + if used_rows == 0: + self._state["data"] = np.empty((0, self._full_num_fields), dtype=data.dtype) + self._state["cell_starts"].fill(0) + return + + compacted = np.empty((used_rows, self._full_num_fields), dtype=data.dtype) + starts = np.zeros_like(self._state["cell_starts"]) + cursor = 0 + for linear_index in range(_cell_count(self._state["shape"])): + length = self._cell_row_count(linear_index) + starts[linear_index] = cursor + if length > 0: + cell = self._cell_matrix(linear_index) + compacted[cursor : cursor + length] = cell + cursor += length + self._state["data"] = compacted + self._state["cell_starts"] = starts + + # ------------------------------------------------------------------ # + # Python data model + # ------------------------------------------------------------------ # + + def __len__(self) -> int: + """Return ``shape[0]`` for non-scalar selections.""" + if self.shape == (): + raise TypeError("len() of unsized 0D Vector") + return self.shape[0] - vector_copy = Vector.from_shape( + def __repr__(self) -> str: + return "\n".join( + [ + f"quantem.Vector, shape={self.shape}, name={self.name}", + f" fields = {self.fields}", + f" units: {self.units}", + ] + ) + + __str__ = __repr__ + + def copy(self) -> "Vector": + """Return a deep copy of the current selection.""" + copied = self.__class__( shape=self.shape, - name=self.name, fields=self.fields, units=self.units, + name=self.name, + metadata=copy.deepcopy(self.metadata), + _token=self.__class__._token, + ) + target_cells = copied._selected_cell_indices() + source_arrays = [ + self._selected_cell_matrix(index).copy() for index in self._selected_cell_indices() + ] + copied._replace_cells(target_cells, source_arrays) + return copied + + def __getitem__(self, idx: Any) -> "Vector": + """Return a fixed-grid selection as another Vector view.""" + if _looks_like_field_selector(idx): + raise TypeError("Use select_fields(...) for field selection.") + if idx is Ellipsis: + return self + + selection_shape, selection_indices = _select_linear_indices( + self.shape, + self._selected_cell_indices(), + idx, + ) + return self._from_view( + self._state, + selection_shape, + selection_indices, + self._selected_fields, ) - vector_copy._data = copy.deepcopy(self._data) - return vector_copy - def flatten(self) -> NDArray: - """ - Flatten the vector into a 2D numpy array. + def __setitem__(self, idx: Any, value: Any) -> None: + """Assign to a fixed-grid selection.""" + if idx is Ellipsis: + target = self + else: + target = self[idx] + target._assign(value) - Returns - ------- - NDArray - A 2D numpy array containing all data, with shape (total_rows, num_fields). - """ + # ------------------------------------------------------------------ # + # Arithmetic operators + # ------------------------------------------------------------------ # - def collect_arrays(data: Any) -> List[NDArray]: - if isinstance(data, np.ndarray): - return [data] - elif isinstance(data, list): - arrays = [] - for item in data: - arrays.extend(collect_arrays(item)) - return arrays - else: - return [] - - arrays = collect_arrays(self._data) - if not arrays: - return np.empty((0, self.num_fields)) - return np.vstack(arrays) + def __array_ufunc__(self, ufunc: Any, method: str, *inputs: Any, **kwargs: Any) -> Any: + """Apply supported NumPy ufuncs elementwise. - def __repr__(self) -> str: - description = [ - f"quantem.Vector, shape={self._shape}, name={self._name}", - f" fields = {self._fields}", - f" units: {self._units}", + Supported operations are limited to elementwise ``__call__`` ufuncs. The + result preserves the current selection shape and fields. + """ + if method != "__call__": + return NotImplemented + + out = kwargs.get("out") + if out is not None: + return NotImplemented + + vector_inputs = [value for value in inputs if isinstance(value, Vector)] + if not vector_inputs: + return NotImplemented + + template = vector_inputs[0] + row_counts = template.row_counts() + total_rows = sum(row_counts) + + for other in vector_inputs[1:]: + if other.shape != template.shape: + raise ValueError("Vector ufunc inputs must have matching fixed-grid shapes.") + if other.num_fields != template.num_fields: + raise ValueError("Vector ufunc inputs must have matching field counts.") + if other.row_counts() != row_counts: + raise ValueError("Vector ufunc inputs must have matching per-cell row counts.") + + flat_inputs = [ + _normalize_ufunc_input(value, total_rows, template.num_fields) for value in inputs ] - return "\n".join(description) + result = ufunc(*flat_inputs, **kwargs) + if isinstance(result, tuple): + return tuple(_vector_from_flat_result(template, item, row_counts) for item in result) + return _vector_from_flat_result(template, result, row_counts) - def __str__(self) -> str: - description = [ - f"quantem.Vector, shape={self._shape}, name={self._name}", - f" fields = {self._fields}", - f" units: {self._units}", - ] - return "\n".join(description) + def __add__(self, other: Any) -> "Vector": + return self._binary_op(other, np.add) - @property - def metadata(self) -> dict: - return self._metadata + def __sub__(self, other: Any) -> "Vector": + return self._binary_op(other, np.subtract) - @property - def shape(self) -> Tuple[int, ...]: - """ - Get the shape of the vector. + def __mul__(self, other: Any) -> "Vector": + return self._binary_op(other, np.multiply) - Returns - ------- - Tuple[int, ...] - The dimensions of the vector. - """ - return self._shape + def __truediv__(self, other: Any) -> "Vector": + return self._binary_op(other, np.divide) - @shape.setter - def shape(self, value: Tuple[int, ...]) -> None: - """ - Set the shape of the vector. + def __floordiv__(self, other: Any) -> "Vector": + return self._binary_op(other, np.floor_divide) - Parameters - ---------- - value : Tuple[int, ...] - The new shape. All dimensions must be positive. - - Raises - ------ - ValueError - If any dimension is not positive. - TypeError - If value is not a tuple or contains non-integer values. - """ - self._shape = validate_shape(value) + def __mod__(self, other: Any) -> "Vector": + return self._binary_op(other, np.mod) - @property - def num_fields(self) -> int: - """ - Get the number of fields in the vector. + def __pow__(self, other: Any) -> "Vector": + return self._binary_op(other, np.power) - Returns - ------- - int - The number of fields. - """ - return len(self._fields) + def __radd__(self, other: Any) -> "Vector": + return self._binary_op(other, np.add, reverse=True) - @property - def name(self) -> str: - """ - Get the name of the vector. + def __rmul__(self, other: Any) -> "Vector": + return self._binary_op(other, np.multiply, reverse=True) - Returns - ------- - str - The name of the vector - """ - return self._name + def __rsub__(self, other: Any) -> "Vector": + return self._binary_op(other, np.subtract, reverse=True) - @name.setter - def name(self, value: str) -> None: - """ - Set the name of the vector. + def __rtruediv__(self, other: Any) -> "Vector": + return self._binary_op(other, np.divide, reverse=True) - Parameters - ---------- - value : str - The new name of the vector - """ - self._name = str(value) + def __rfloordiv__(self, other: Any) -> "Vector": + return self._binary_op(other, np.floor_divide, reverse=True) - @property - def fields(self) -> List[str]: - """ - Get the field names of the vector. + def __rmod__(self, other: Any) -> "Vector": + return self._binary_op(other, np.mod, reverse=True) - Returns - ------- - List[str] - The list of field names. - """ - return self._fields + def __rpow__(self, other: Any) -> "Vector": + return self._binary_op(other, np.power, reverse=True) - @fields.setter - def fields(self, value: List[str]) -> None: - """ - Set the field names of the vector. + def __iadd__(self, other: Any) -> "Vector": + self._inplace_op(other, np.add) + return self - Parameters - ---------- - value : List[str] - The new field names. Must match num_fields and be unique. - - Raises - ------ - ValueError - If length doesn't match num_fields or if there are duplicates. - TypeError - If value is not a list or contains non-string values. - """ - self._fields = validate_fields(value) + def __isub__(self, other: Any) -> "Vector": + self._inplace_op(other, np.subtract) + return self - @property - def units(self) -> List[str]: - """ - Get the units of the vector's fields. + def __imul__(self, other: Any) -> "Vector": + self._inplace_op(other, np.multiply) + return self - Returns - ------- - List[str] - The list of units, one per field. - """ - return self._units + def __itruediv__(self, other: Any) -> "Vector": + self._inplace_op(other, np.divide) + return self + + def __ifloordiv__(self, other: Any) -> "Vector": + self._inplace_op(other, np.floor_divide) + return self + + def __imod__(self, other: Any) -> "Vector": + self._inplace_op(other, np.mod) + return self - @units.setter - def units(self, value: List[str]) -> None: + def __ipow__(self, other: Any) -> "Vector": + self._inplace_op(other, np.power) + return self + + def __neg__(self) -> "Vector": + return self._binary_op(-1, np.multiply) + + def __pos__(self) -> "Vector": + return self.copy() + + def __abs__(self) -> "Vector": + result = self.copy() + result._inplace_unary(np.abs) + return result + + # ------------------------------------------------------------------ # + # I/O + # ------------------------------------------------------------------ # + + def save( + self, + path: str | Path, + mode: Literal["w", "o"] = "w", + store: Literal["auto", "zip", "dir"] = "auto", + skip: str | type | Sequence[str | type] = (), + compression_level: int | None = 4, + ) -> None: """ - Set the units of the vector's fields. + Save the Vector object to disk using Zarr serialization. self.compact() is called before + saving to reduce file size if possible. Parameters ---------- - value : List[str] - The new units. Must match num_fields. - - Raises - ------ - ValueError - If length doesn't match num_fields. - TypeError - If value is not a list or contains non-string values. + path : str or Path + Target file path. Use '.zip' extension for zip format, otherwise a directory. + mode : {'w', 'o'} + 'w' = write only if file doesn't exist, 'o' = overwrite if it does. + store : {'auto', 'zip', 'dir'} + Storage format. 'auto' infers from file extension. + skip : str, type, or list of (str or type) + Attribute names/types to skip (by name or type) during serialization. + compression_level : int or None + If set (0–9), applies Zstandard compression with Blosc backend at that level. + Level 0 disables compression. Raises ValueError if > 9. + + Notes + ----- + Skipped attribute names and types are also stored in the file metadata for correct + round-trip skipping during load(). """ - self._units = validate_vector_units(value, self.num_fields) + self.compact() + super().save( + path, + mode=mode, + store=store, + skip=skip, + compression_level=compression_level, + ) + + # ------------------------------------------------------------------ # + # Private helpers — backing-store access + # ------------------------------------------------------------------ # @property - def data(self) -> List[Any]: + def _full_num_fields(self) -> int: + return len(self._state["fields"]) + + def _field_indices(self) -> NDArray[np.int64]: + """Map selected field names to column indices in the backing buffer.""" + if self._selected_fields is None: + return np.arange(self._full_num_fields, dtype=np.int64) + + lookup = {field: i for i, field in enumerate(self._state["fields"])} + try: + return np.array([lookup[field] for field in self._selected_fields], dtype=np.int64) + except KeyError as exc: + raise KeyError(f"Unknown field(s): {[str(exc.args[0])]}") from exc + + def _require_full_field_view(self, operation: str) -> None: + """Raise if a schema-changing/full-row operation is attempted on a field view.""" + if self._selected_fields is not None: + raise ValueError(f"{operation} is only allowed when all fields are selected.") + + def _selected_cell_indices(self) -> NDArray[np.int64]: + """Return linear cell indices for the current fixed-grid selection.""" + if self._selection_indices is None: + return np.arange(_cell_count(self._state["shape"]), dtype=np.int64) + return self._selection_indices + + def _cell_row_count(self, linear_index: int) -> int: + """Return the row count for one cell in the backing buffer.""" + return int(self._state["cell_lengths"][linear_index]) + + def _cell_matrix(self, linear_index: int) -> NDArray[Any]: + """Return the full backing matrix for one cell.""" + start = int(self._state["cell_starts"][linear_index]) + length = int(self._state["cell_lengths"][linear_index]) + return self._state["data"][start : start + length] + + def _selected_cell_matrix(self, linear_index: int) -> NDArray[Any]: + """Return one cell with the current field selection applied.""" + cell = self._cell_matrix(linear_index) + cols = self._field_indices() + if cols.size == self._full_num_fields: + return cell + if cols.size == 1: + col = int(cols[0]) + return cell[:, col : col + 1] + if _is_contiguous(cols): + return cell[:, int(cols[0]) : int(cols[-1]) + 1] + return cell[:, cols].copy() + + def _replace_cells(self, targets: NDArray[np.int64], arrays: Sequence[NDArray[Any]]) -> None: + """Replace complete cells in the compact row buffer. + + Whole-cell replacement is implemented by appending the new payload rows to + the end of the backing buffer and then updating ``cell_starts`` / + ``cell_lengths`` for the targeted cells. This keeps the operation simple + and makes overlapping assignment semantics easy to reason about, but it + leaves the previous rows unreachable until compaction removes them. """ - Get the raw data of the vector. + if len(targets) != len(arrays): + raise ValueError("Target cell count does not match source cell count.") + if len(targets) == 0: + return - Returns - ------- - List[Any] - The nested list structure containing the vector's data. - """ - return self._data + normalized = [_coerce_cell_array(array, self._full_num_fields) for array in arrays] + payloads = [array for array in normalized if array.shape[0] > 0] + if payloads: + appended = np.vstack(payloads) + self._state["data"] = np.concatenate((self._state["data"], appended), axis=0) + + cursor = self._state["data"].shape[0] - sum(array.shape[0] for array in normalized) + for target, array in zip(targets, normalized): + self._state["cell_starts"][target] = cursor + self._state["cell_lengths"][target] = array.shape[0] + cursor += array.shape[0] + + self._maybe_compact_storage() + + def _expand_storage(self, num_new_fields: int) -> None: + """Append new ``np.nan``-initialized columns for added fields.""" + data = self._state["data"] + dtype = np.result_type(data.dtype, float) + if data.shape[0] == 0: + self._state["data"] = np.empty((0, data.shape[1] + num_new_fields), dtype=dtype) + return - @data.setter - def data(self, value: List[Any]) -> None: - """ - Set the raw data of the vector. + filler = np.full((data.shape[0], num_new_fields), np.nan, dtype=dtype) + self._state["data"] = np.concatenate((data.astype(dtype, copy=False), filler), axis=1) - Parameters - ---------- - value : List[Any] - The new data structure. Must match the vector's shape and num_fields. - - Raises - ------ - ValueError - If the data structure doesn't match shape or num_fields. - TypeError - If value is not a list or contains invalid data types. - """ - self._data = validate_vector_data(value, self.shape, self.num_fields) + def _maybe_compact_storage(self) -> None: + """Compact automatically once dead rows become materially larger than live rows.""" + data = self._state["data"] + used_rows = int(self._state["cell_lengths"].sum()) + if data.shape[0] <= used_rows + 1024 or data.shape[0] <= 2 * used_rows: + return + self.compact() + # ------------------------------------------------------------------ # + # Private helpers — assignment + # ------------------------------------------------------------------ # -# Helper function for nesting lists -def nested_list(shape: Tuple[int, ...], fill: Any = None) -> Any: - if len(shape) == 0: - return fill - return [nested_list(shape[1:], fill) for _ in range(shape[0])] + def _assign(self, value: Any) -> None: + """Dispatch assignment based on whether all fields or a subset are selected.""" + if self._selected_fields is None: + self._assign_full_cells(value) + else: + self._assign_selected_fields(value) + def _assign_full_cells(self, value: Any) -> None: + """Replace full cell payloads. -# Helper class for numerical field operations -class _FieldView: - def __init__(self, vector: Vector, field_name: str) -> None: - self.vector = vector - self.field_name = field_name - self.field_index = vector._fields.index(field_name) + Full-cell assignment may change the ragged row count of each targeted + cell, because the existing cell matrix is replaced as a whole. + """ + targets = self._selected_cell_indices() + if isinstance(value, Vector): + source_cells = value._selected_cell_indices() + if len(targets) != len(source_cells): + raise ValueError(f"Expected {len(targets)} cells, got {len(source_cells)}") + if value.num_fields != self.num_fields: + raise ValueError(f"Expected {self.num_fields} fields, got {value.num_fields}") + arrays = [value._selected_cell_matrix(index).copy() for index in source_cells] + self._replace_cells(targets, arrays) + return - def _apply_op(self, op: Any) -> None: - def apply(arr: Any) -> None: - if isinstance(arr, np.ndarray): - arr[:, self.field_index] = op(arr[:, self.field_index]) - elif isinstance(arr, list): - for sub in arr: - apply(sub) + array = _coerce_cell_array(value, self.num_fields) + self._replace_cells(targets, [array] * len(targets)) - apply(self.vector._data) + def _assign_selected_fields(self, value: Any) -> None: + """Update only the selected columns while preserving row counts. - def __iadd__(self, other: Union[float, int, np.ndarray]) -> "_FieldView": - """Handle in-place addition (+=).""" - self._apply_op(lambda x: x + other) - return self + This is the in-place path for assignments such as + ``vector.select_fields("kx")[...] = rhs``. The target cell structure is + preserved, so each target cell keeps its existing row count and only the + selected columns are overwritten. + """ + targets = self._selected_cell_indices() + field_indices = self._field_indices() + row_counts = [self._cell_row_count(index) for index in targets] + total_rows = sum(row_counts) + + if isinstance(value, Vector): + source_cells = value._selected_cell_indices() + if len(targets) != len(source_cells): + raise ValueError(f"Expected {len(targets)} cells, got {len(source_cells)}") + if value.num_fields != self.num_fields: + raise ValueError(f"Expected {self.num_fields} fields, got {value.num_fields}") + source_counts = [value._cell_row_count(index) for index in source_cells] + if row_counts != source_counts: + raise ValueError("Per-cell row counts must match for field-selected assignment.") + snapshots = [value._selected_cell_matrix(index).copy() for index in source_cells] + for target, array in zip(targets, snapshots): + cell = self._cell_matrix(int(target)) + if array.shape[0] > 0: + cell[:, field_indices] = array + return - def __isub__(self, other: Union[float, int, np.ndarray]) -> "_FieldView": - """Handle in-place subtraction (-=).""" - self._apply_op(lambda x: x - other) - return self + if np.isscalar(value): + for target in targets: + cell = self._cell_matrix(int(target)) + if cell.shape[0] > 0: + cell[:, field_indices] = value + return - def __imul__(self, other: Union[float, int, np.ndarray]) -> "_FieldView": - """Handle in-place multiplication (*=).""" - self._apply_op(lambda x: x * other) - return self + broadcast = _broadcast_field_values(value, total_rows, self.num_fields) + cursor = 0 + for target, rows in zip(targets, row_counts): + chunk = broadcast[cursor : cursor + rows] + cell = self._cell_matrix(int(target)) + if rows > 0: + cell[:, field_indices] = chunk + cursor += rows + + # ------------------------------------------------------------------ # + # Private helpers — arithmetic + # ------------------------------------------------------------------ # + + def _binary_op(self, other: Any, op: Any, reverse: bool = False) -> "Vector": + """Return a new Vector produced by elementwise arithmetic.""" + result = self.copy() + result._inplace_op(other, op, reverse=reverse) + return result - def __itruediv__(self, other: Union[float, int, np.ndarray]) -> "_FieldView": - """Handle in-place division (/=).""" - self._apply_op(lambda x: x / other) - return self + def _inplace_unary(self, op: Any) -> None: + """Apply a unary elementwise operation in-place to the selected fields.""" + targets = self._selected_cell_indices() + field_indices = self._field_indices() + for target in targets: + cell = self._cell_matrix(int(target)) + lhs = cell[:, field_indices] + if lhs.shape[0] > 0: + cell[:, field_indices] = op(lhs) + + def _inplace_op(self, other: Any, op: Any, reverse: bool = False) -> None: + """Apply elementwise arithmetic in-place to the selected fields.""" + targets = self._selected_cell_indices() + field_indices = self._field_indices() + row_counts = [self._cell_row_count(index) for index in targets] + total_rows = sum(row_counts) + + if isinstance(other, Vector): + source_cells = other._selected_cell_indices() + if len(targets) != len(source_cells): + raise ValueError(f"Expected {len(targets)} cells, got {len(source_cells)}") + if other.num_fields != self.num_fields: + raise ValueError(f"Expected {self.num_fields} fields, got {other.num_fields}") + source_counts = [other._cell_row_count(index) for index in source_cells] + if row_counts != source_counts: + raise ValueError("Per-cell row counts must match for Vector arithmetic.") + snapshots = [other._selected_cell_matrix(index).copy() for index in source_cells] + for target, rhs in zip(targets, snapshots): + cell = self._cell_matrix(int(target)) + lhs = cell[:, field_indices] + cell[:, field_indices] = op(rhs, lhs) if reverse else op(lhs, rhs) + return - def __ifloordiv__(self, other: Union[float, int, np.ndarray]) -> "_FieldView": - """Handle in-place floor division (//=).""" - self._apply_op(lambda x: x // other) - return self + if np.isscalar(other): + for target in targets: + cell = self._cell_matrix(int(target)) + lhs = cell[:, field_indices] + if lhs.shape[0] > 0: + cell[:, field_indices] = op(other, lhs) if reverse else op(lhs, other) + return - def __imod__(self, other: Union[float, int, np.ndarray]) -> "_FieldView": - """Handle in-place modulo (%=).""" - self._apply_op(lambda x: x % other) - return self + broadcast = _broadcast_field_values(other, total_rows, self.num_fields) + cursor = 0 + for target, rows in zip(targets, row_counts): + chunk = broadcast[cursor : cursor + rows] + cell = self._cell_matrix(int(target)) + lhs = cell[:, field_indices] + if rows > 0: + cell[:, field_indices] = op(chunk, lhs) if reverse else op(lhs, chunk) + cursor += rows + + +def _resolve_fields( + fields: Sequence[str] | None, + num_fields: int | None, + inferred: int | None, +) -> list[str]: + """Resolve field names from constructor arguments. + + ``inferred`` is the field count inferred from data; pass ``None`` when there + is no data source and explicit fields/num_fields are required. + """ + if fields is not None: + root_fields = validate_fields(list(fields)) + count = len(root_fields) + if num_fields is not None and count != num_fields: + raise ValueError( + f"num_fields ({num_fields}) does not match length of fields ({count})" + ) + if inferred is not None and count != inferred: + raise ValueError(f"num_fields ({inferred}) does not match length of fields ({count})") + return root_fields + if num_fields is not None: + count = validate_num_fields(num_fields) + if inferred is not None and count != inferred: + raise ValueError( + f"Provided num_fields ({count}) does not match inferred ({inferred})." + ) + return [f"field_{i}" for i in range(count)] + if inferred is not None: + return [f"field_{i}" for i in range(inferred)] + raise ValueError("Must specify either 'fields' or 'num_fields'.") + + +def _cell_count(shape: tuple[int, ...]) -> int: + """Return the number of fixed-grid cells in a shape.""" + return int(np.prod(shape, dtype=np.int64)) if shape else 1 + + +def _normalize_field_names(field_names: str | Sequence[str]) -> tuple[str, ...]: + """Normalize one-or-many field names into a validated tuple.""" + if isinstance(field_names, str): + normalized = (field_names,) + else: + normalized = tuple(field_names) + if not normalized: + raise ValueError("At least one field name is required.") + validate_fields(list(normalized)) + return normalized + + +def _normalize_units(units: str | Sequence[str] | None, count: int) -> list[str]: + """Normalize field units to a list matching ``count``.""" + if units is None: + return ["none"] * count + if isinstance(units, str): + if count != 1: + raise ValueError("A single unit can only be provided for a single field.") + return [units] + normalized = list(units) + if len(normalized) != count: + raise ValueError(f"Expected {count} units, got {len(normalized)}") + return normalized + + +def _looks_like_field_selector(idx: Any) -> bool: + """Return True for indices that look like field selection by mistake.""" + if isinstance(idx, str): + return True + if isinstance(idx, tuple) and any(_looks_like_field_selector(item) for item in idx): + return True + if isinstance(idx, list) and idx and all(isinstance(item, str) for item in idx): + return True + return False + + +def _coerce_cell_array(value: Any, num_fields: int) -> NDArray[Any]: + """Normalize a single-cell payload to shape ``(n_rows, num_fields)``.""" + if isinstance(value, Vector): + if value.shape != (): + raise ValueError("Expected a 0D Vector for single-cell assignment.") + array = value.array.copy() + else: + array = np.asarray(value) + + if array.ndim == 0: + raise ValueError("Cell assignment requires a 2D array.") + if array.ndim == 1: + if array.size == 0: + array = np.empty((0, num_fields), dtype=float) + elif num_fields == 1: + array = array.reshape(-1, 1) + else: + array = array.reshape(1, -1) + if array.ndim != 2: + raise ValueError("Cell assignment requires a 2D array.") + if array.shape[1] != num_fields: + raise ValueError(f"Expected {num_fields} fields, got {array.shape[1]}") + return array + + +def _flatten_fixed_grid(node: Any) -> tuple[tuple[int, ...], list[NDArray[Any]]]: + """Recursively flatten nested fixed-grid input into row-major cell order.""" + if isinstance(node, np.ndarray): + return (), [_coerce_inferred_cell_array(node)] + if not isinstance(node, (list, tuple)): + raise TypeError("Data must be a nested list/tuple of cell arrays or row sequences.") + if _looks_like_cell_rows(node): + return (), [_coerce_inferred_cell_array(node)] + if len(node) == 0: + return (0,), [] + + child_shape: tuple[int, ...] | None = None + cells: list[NDArray[Any]] = [] + for child in node: + shape, child_cells = _flatten_fixed_grid(child) + if child_shape is None: + child_shape = shape + elif child_shape != shape: + raise ValueError("All nested fixed-grid branches must have matching shapes.") + cells.extend(child_cells) + + assert child_shape is not None + return (len(node),) + child_shape, cells + + +def _looks_like_cell_rows(node: Sequence[Any]) -> bool: + """Return True when a sequence should be interpreted as cell rows, not grid nesting.""" + if len(node) == 0: + return True + return all(_is_row_like(item) for item in node) + + +def _is_row_like(item: Any) -> bool: + """Return True for a single row of scalar values.""" + if isinstance(item, np.ndarray): + return item.ndim == 1 + if not isinstance(item, (list, tuple)): + return False + return all(np.isscalar(value) for value in item) + + +def _coerce_inferred_cell_array(value: Any) -> NDArray[Any]: + """Infer a 2D cell array from row-like input during ``from_data``.""" + array = np.asarray(value) + if array.ndim == 0: + raise ValueError("Cell data must be 1D or 2D.") + if array.ndim == 1: + if array.size == 0: + return np.empty((0, 0), dtype=float) + return array.reshape(1, -1) + if array.ndim != 2: + raise ValueError("Cell data must be 1D or 2D.") + return array + + +def _select_linear_indices( + shape: tuple[int, ...], + current_indices: NDArray[np.int64], + idx: Any, +) -> tuple[tuple[int, ...], NDArray[np.int64]]: + """Apply fixed-grid indexing to a flattened cell-index view. + + ``current_indices`` stores the linear cell indices represented by the current + selection. This helper reshapes those indices to the current selection shape, + applies NumPy-like indexing on the fixed-grid axes, and then returns: + - the output fixed-grid shape + - the flattened linear indices of the selected cells, in row-major order + """ + if shape == (): + if idx in ((), Ellipsis): + return (), np.array([int(current_indices[0])], dtype=np.int64) + raise IndexError("Too many indices for 0D Vector") + + index_tuple = _normalize_index_tuple(idx, len(shape)) + current_grid = current_indices.reshape(shape) + + axis_positions: list[NDArray[np.int64]] = [] + out_shape: list[int] = [] + scalar_axes: list[bool] = [] + for axis, axis_index in enumerate(index_tuple): + positions, is_scalar = _positions_for_axis(axis_index, shape[axis]) + axis_positions.append(positions) + scalar_axes.append(is_scalar) + if not is_scalar: + out_shape.append(len(positions)) + + if all(scalar_axes): + scalar_key = tuple(int(positions[0]) for positions in axis_positions) + value = int(current_grid[scalar_key]) + return (), np.array([value], dtype=np.int64) + + mesh_inputs = [ + positions if not is_scalar else positions[:1] + for positions, is_scalar in zip(axis_positions, scalar_axes) + ] + grids = np.meshgrid(*mesh_inputs, indexing="ij") + selected = np.asarray(current_grid[tuple(grids)], dtype=np.int64).reshape(-1) + return tuple(out_shape), selected + + +def _normalize_index_tuple(idx: Any, ndim: int) -> tuple[Any, ...]: + """Normalize fixed-grid indexing to a full ``ndim``-length tuple.""" + if idx is Ellipsis: + return (slice(None),) * ndim + if not isinstance(idx, tuple): + idx = (idx,) + + ellipsis_count = sum(item is Ellipsis for item in idx) + if ellipsis_count > 1: + raise IndexError("An index can only have a single ellipsis.") + if ellipsis_count == 1: + ellipsis_pos = idx.index(Ellipsis) + fill = ndim - (len(idx) - 1) + idx = idx[:ellipsis_pos] + (slice(None),) * fill + idx[ellipsis_pos + 1 :] + if len(idx) > ndim: + raise IndexError(f"Too many indices for Vector: expected {ndim}, got {len(idx)}") + if len(idx) < ndim: + idx = idx + (slice(None),) * (ndim - len(idx)) + return idx + + +def _positions_for_axis(axis_index: Any, size: int) -> tuple[NDArray[np.int64], bool]: + """Resolve one axis index into concrete positions and scalar-vs-vector shape behavior.""" + if isinstance(axis_index, (bool, np.bool_)): + raise TypeError("Boolean scalars are not valid Vector indices.") + + if isinstance(axis_index, (int, np.integer)): + index = int(axis_index) + if index < 0: + index += size + if index < 0 or index >= size: + raise IndexError("Vector index out of range") + return np.array([index], dtype=np.int64), True + + if isinstance(axis_index, slice): + return np.arange(size, dtype=np.int64)[axis_index], False + + array = np.asarray(axis_index) + if array.ndim == 0: + if np.issubdtype(array.dtype, np.integer): + return _positions_for_axis(int(array.item()), size) + raise TypeError(f"Unsupported index type: {type(axis_index)!r}") + + if array.dtype == bool or np.issubdtype(array.dtype, np.bool_): + if array.ndim != 1: + raise IndexError("Full-grid boolean masks are not supported.") + if array.shape[0] != size: + raise IndexError( + f"Boolean mask length {array.shape[0]} does not match axis length {size}" + ) + return np.flatnonzero(array).astype(np.int64, copy=False), False + + if array.ndim != 1: + raise IndexError("Fancy indexing arrays must be one-dimensional.") + if array.size == 0: + return np.array([], dtype=np.int64), False + if not np.issubdtype(array.dtype, np.integer): + raise TypeError("Fancy indices must be integers or booleans.") + + positions = array.astype(np.int64, copy=True) + positions[positions < 0] += size + if np.any((positions < 0) | (positions >= size)): + raise IndexError("Vector index out of range") + return positions, False + + +def _broadcast_field_values(value: Any, total_rows: int, num_fields: int) -> NDArray[Any]: + """Broadcast array-like input to flattened rowwise assignment shape.""" + array = np.asarray(value) + if array.ndim == 0: + return np.broadcast_to(array.reshape(1, 1), (total_rows, num_fields)) + if num_fields == 1 and array.ndim == 1: + if total_rows == 0 and array.shape[0] == 0: + return array.reshape(0, 1) + if array.shape[0] != total_rows: + raise ValueError(f"Expected {total_rows} values, got {array.shape[0]}") + return array.reshape(total_rows, 1) + try: + return np.broadcast_to(array, (total_rows, num_fields)) + except ValueError as exc: + raise ValueError( + f"Cannot broadcast value with shape {array.shape} to ({total_rows}, {num_fields})" + ) from exc + + +def _normalize_ufunc_input(value: Any, total_rows: int, num_fields: int) -> Any: + """Normalize one ufunc input to flattened Vector-compatible form.""" + if isinstance(value, Vector): + return value.flatten() + if np.isscalar(value): + return value + return _broadcast_field_values(value, total_rows, num_fields) + + +def _vector_from_flat_result( + template: Vector, + values: Any, + row_counts: list[int], +) -> Vector: + """Build a Vector from flattened rowwise result data.""" + total_rows = sum(row_counts) + flat_values = _broadcast_field_values(values, total_rows, template.num_fields) + + result = Vector.from_shape( + shape=template.shape, + fields=template.fields, + units=template.units, + name=template.name, + ) + result._state["metadata"] = copy.deepcopy(template.metadata) - def __ipow__(self, other: Union[float, int, np.ndarray]) -> "_FieldView": - """Handle in-place power (**=).""" - self._apply_op(lambda x: x**other) - return self + if total_rows == 0: + result._state["data"] = np.empty((0, template.num_fields), dtype=flat_values.dtype) + return result + + cursor = 0 + cells: list[NDArray[Any]] = [] + for rows in row_counts: + cells.append(flat_values[cursor : cursor + rows].copy()) + cursor += rows + + result._replace_cells(result._selected_cell_indices(), cells) + return result - def flatten(self) -> NDArray: - def collect(arr: Any) -> List[NDArray]: - if isinstance(arr, np.ndarray): - return [arr[:, self.field_index]] - elif isinstance(arr, list): - result = [] - for sub in arr: - result.extend(collect(sub)) - return result - else: - return [] - - arrays = collect(self.vector._data) - if not arrays: - return np.empty((0,), dtype=float) - return np.concatenate(arrays, axis=0) - - def set_flattened(self, values: ArrayLike) -> None: - """ - Set the field values across the entire Vector from a 1D flattened array. - """ - def fill(arr: Any, values: NDArray, cursor: int) -> int: - if isinstance(arr, np.ndarray): - n = arr.shape[0] - arr[:, self.field_index] = values[cursor : cursor + n] - return cursor + n - elif isinstance(arr, list): - for sub in arr: - cursor = fill(sub, values, cursor) - return cursor - return cursor - - values = np.asarray(values) - if values.ndim != 1: - raise ValueError("Input to set_flattened must be a 1D array.") - - expected = self.flatten().shape[0] - if values.shape[0] != expected: - raise ValueError(f"Expected {expected} values, got {values.shape[0]}") - - fill(self.vector._data, values, cursor=0) - - def __getitem__( - self, idx: Union[Tuple[Union[int, slice], ...], int, slice] - ) -> Union[NDArray, "_FieldView"]: - # Optionally allow v['field0'][0, 1] to get subregion, or v['field0'][...] slice - sub = self.vector[idx] - if isinstance(sub, Vector): - return sub[self.field_name] - elif isinstance(sub, np.ndarray): - return sub[:, self.field_index] - return cast(NDArray, None) - - def __array__(self) -> np.ndarray: - """Convert to numpy array when needed.""" - return self.flatten() +def _is_contiguous(indices: NDArray[np.int64]) -> bool: + """Return True when integer column indices form one contiguous slice.""" + if indices.size <= 1: + return True + return bool(np.all(indices[1:] - indices[:-1] == 1)) diff --git a/src/quantem/core/visualization/visualization_utils.py b/src/quantem/core/visualization/visualization_utils.py index afe475b1..b69e3c86 100644 --- a/src/quantem/core/visualization/visualization_utils.py +++ b/src/quantem/core/visualization/visualization_utils.py @@ -503,6 +503,14 @@ def bilinear_histogram_2d( x0, y0 = origin x1, y1 = x0 + Nx * dx, y0 + Ny * dy + x = _as_histogram_vector(x, "x") + y = _as_histogram_vector(y, "y") + weight = _as_histogram_vector(weight, "weight") + if not (x.shape == y.shape == weight.shape): + raise ValueError( + f"x, y, and weight must have matching shapes after coercion, got {x.shape}, {y.shape}, {weight.shape}" + ) + # Convert shape tuple to list for binned_statistic_2d bins: Sequence[int] = [Nx, Ny] hist, _, _, _ = binned_statistic_2d( @@ -517,6 +525,15 @@ def bilinear_histogram_2d( return hist # shape = (Nx, Ny), i.e., array[x, y] +def _as_histogram_vector(value: NDArray, name: str) -> NDArray: + array = np.asarray(value) + if array.ndim == 1: + return array + if array.ndim == 2 and 1 in array.shape: + return array.reshape(-1) + raise ValueError(f"{name} must be 1D or shape (N, 1)/(1, N), got {array.shape}") + + def axes_with_inset( axsize=(4, 4), ax_size_embed=None, # None -> 0.25 of main axes in each dimension (fractional) diff --git a/tests/datastructures/test_vector.py b/tests/datastructures/test_vector.py index e2085b76..954059d6 100644 --- a/tests/datastructures/test_vector.py +++ b/tests/datastructures/test_vector.py @@ -1,433 +1,436 @@ +import zipfile + import numpy as np import pytest from quantem.core.datastructures.vector import Vector +from quantem.core.io.serialize import load -class TestVector: - """Test suite for the Vector class.""" +def make_line_vector() -> Vector: + v = Vector.from_shape( + shape=(4,), + fields=["intensity", "kx", "ky"], + units=["a.u.", "px", "px"], + name="line", + ) + v[0] = np.array([[1.0, 10.0, 100.0], [2.0, 20.0, 200.0]]) + v[1] = np.array([[3.0, 30.0, 300.0]]) + v[2] = np.array([[4.0, 40.0, 400.0], [5.0, 50.0, 500.0]]) + v[3] = np.array([[6.0, 60.0, 600.0]]) + return v + + +def make_grid_vector() -> Vector: + v = Vector.from_shape(shape=(3, 2), fields=["intensity", "kx", "ky"]) + for i in range(3): + for j in range(2): + base = float(i * 10 + j) + v[i, j] = np.array([[base, base + 100.0, base + 200.0]]) + return v - def test_initialization(self): - """Test Vector initialization with different parameters.""" - # Test with fields - v1 = Vector.from_shape(shape=(2, 3), fields=["field0", "field1", "field2"]) + +class TestVector: + def test_initialization_and_len(self): + v1 = Vector.from_shape(shape=(2, 3), fields=["a", "b", "c"]) assert v1.shape == (2, 3) + assert len(v1) == 2 + assert v1.num_cells == 6 assert v1.num_fields == 3 - assert v1.fields == ["field0", "field1", "field2"] + assert v1.dtype == np.dtype(float) + assert v1.fields == ["a", "b", "c"] assert v1.units == ["none", "none", "none"] assert v1.name == "2d ragged array" - assert hasattr(v1, "metadata") - - # Test with num_fields - v2 = Vector.from_shape(shape=(2, 3), num_fields=3) - assert v2.shape == (2, 3) - assert v2.num_fields == 3 - assert v2.fields == ["field_0", "field_1", "field_2"] - assert v2.units == ["none", "none", "none"] - assert hasattr(v2, "metadata") - - # Test with custom name and units - v3 = Vector.from_shape( - shape=(2, 3), - fields=["field0", "field1", "field2"], - name="my_vector", - units=["unit0", "unit1", "unit2"], - ) - assert v3.name == "my_vector" - assert v3.units == ["unit0", "unit1", "unit2"] - assert hasattr(v3, "metadata") + assert v1[0, 0].array.shape == (0, 3) + np.testing.assert_array_equal(v1[0, 0].flatten(), v1[0, 0].array) + + v2 = Vector.from_shape(shape=(2, 3), num_fields=2) + assert v2.fields == ["field_0", "field_1"] + + with pytest.raises(TypeError): + len(v1[0, 0]) - # Test error cases with pytest.raises(ValueError, match="Must specify either 'fields' or 'num_fields'."): Vector.from_shape(shape=(2, 3)) with pytest.raises(ValueError, match="does not match length of fields"): - Vector.from_shape(shape=(2, 3), num_fields=3, fields=["field0", "field1"]) + Vector.from_shape(shape=(2, 3), num_fields=2, fields=["a", "b", "c"]) with pytest.raises(ValueError, match="Duplicate field names"): - Vector.from_shape(shape=(2, 3), fields=["field0", "field0", "field2"]) + Vector.from_shape(shape=(2, 3), fields=["a", "a"]) - def test_data_access(self): - """Test data access and assignment.""" - v = Vector.from_shape(shape=(2, 3), fields=["field0", "field1", "field2"]) + assert str(v1) == ( + "quantem.Vector, shape=(2, 3), name=2d ragged array\n" + " fields = ['a', 'b', 'c']\n" + " units: ['none', 'none', 'none']" + ) - # Set data at specific indices - data1 = np.array([[1.0, 2.0, 3.0]]) - v[0, 0] = data1 - np.testing.assert_array_equal(v.get_data(0, 0), data1) # type: ignore + def test_indexing_and_array_contract(self): + v = make_grid_vector() - # Test get_data method - assert np.array_equal(v.get_data(0, 0), data1) + assert isinstance(v[:2, 1], Vector) + assert v[:2, 1].shape == (2,) + assert v[1].shape == (2,) + assert v[1, 1].shape == () + np.testing.assert_array_equal(v[-1, -1].array, np.array([[21.0, 121.0, 221.0]])) - # Test set_data method - data2 = np.array([[4.0, 5.0, 6.0]]) - v.set_data(data2, 0, 1) - assert np.array_equal(v.get_data(0, 1), data2) + with pytest.raises(ValueError): + _ = v[:, 1].array - # Test error cases - with pytest.raises(IndexError): - v[2, 0] = data1 # Out of bounds + result = v[[-1, 0], 1] + assert result.shape == (2,) + assert result.num_cells == 2 + np.testing.assert_array_equal(result[0].array, np.array([[21.0, 121.0, 221.0]])) + np.testing.assert_array_equal(result[1].array, np.array([[1.0, 101.0, 201.0]])) - with pytest.raises(ValueError): - v[0, 0] = np.array([[1.0, 2.0]]) # Wrong number of fields + def test_select_fields_and_chaining_equivalence(self): + v = make_line_vector() - with pytest.raises(ValueError): - v.set_data(np.array([[1.0, 2.0]]), 0, 0) # Wrong number of fields - - def test_field_operations(self): - """Test field-level operations.""" - v = Vector.from_shape(shape=(2, 3), fields=["field0", "field1", "field2"]) - - # Set initial data - v[0, 0] = np.array([[1.0, 2.0, 3.0]]) - v[0, 1] = np.array([[4.0, 5.0, 6.0]]) - v[0, 2] = np.array([[7.0, 8.0, 9.0]]) - - # Test field access - field_view = v["field0"] - assert ( - hasattr(field_view, "vector") - and hasattr(field_view, "field_name") - and hasattr(field_view, "field_index") + selected = v.select_fields("kx") + assert selected.fields == ["kx"] + assert selected.units == ["px"] + assert selected.shape == v.shape + + np.testing.assert_array_equal( + v.select_fields("kx")[2].array, + v[2].select_fields("kx").array, ) - # Test field operations - v["field0"] += 10 # type: ignore - np.testing.assert_array_equal(v.get_data(0, 0)[:, 0], np.array([11.0])) # type: ignore - np.testing.assert_array_equal(v.get_data(0, 1)[:, 0], np.array([14.0])) # type: ignore - np.testing.assert_array_equal(v.get_data(0, 2)[:, 0], np.array([17.0])) # type: ignore + with pytest.raises(KeyError): + v.select_fields("missing") - # Test applying a function to a field - v["field1"] *= 2 # Using multiplication instead of lambda # type: ignore - np.testing.assert_array_equal(v.get_data(0, 0)[:, 1], np.array([4.0])) # type: ignore - np.testing.assert_array_equal(v.get_data(0, 1)[:, 1], np.array([10.0])) # type: ignore - np.testing.assert_array_equal(v.get_data(0, 2)[:, 1], np.array([16.0])) # type: ignore + with pytest.raises(TypeError): + _ = v["kx"] - # Test field flattening - flat = v["field2"].flatten() - np.testing.assert_array_equal(flat, np.array([3.0, 6.0, 9.0])) # type: ignore + with pytest.raises(TypeError): + _ = v[1, "kx"] + + multi = v.select_fields("intensity", "kx") + assert multi.fields == ["intensity", "kx"] + assert multi.dtype == np.dtype(float) + assert multi.total_rows == 6 + assert multi.row_counts() == [2, 1, 2, 1] + + def test_array_mutation_writes_through_for_single_field(self): + v = make_line_vector() + cell = v.select_fields("kx")[1].array + cell[0, 0] = 99.0 + assert v[1].array[0, 1] == 99.0 + + def test_set_flattened_updates_rowwise(self): + v = make_line_vector() + kx = v.select_fields("kx") + + flat_kx = kx.flatten() + mask = flat_kx >= 30.0 + flat_kx[mask[:, 0], 0] = -1.0 + kx.set_flattened(flat_kx) + + np.testing.assert_array_equal( + kx.flatten(), + np.array([[10.0], [20.0], [-1.0], [-1.0], [-1.0], [-1.0]]), + ) - # Test setting flattened data - v["field2"].set_flattened(np.array([18.0, 18.0, 18.0])) + def test_field_arithmetic_with_scalar_and_ndarray(self): + v = make_line_vector() - # Test error cases - with pytest.raises(KeyError): - v["nonexistent_field"] + kx = v.select_fields("kx") + kx += 10 + np.testing.assert_array_equal( + v.select_fields("kx").flatten(), + np.array([[20.0], [30.0], [40.0], [50.0], [60.0], [70.0]]), + ) - with pytest.raises(ValueError): - v["field0"].set_flattened(np.array([1.0, 2.0])) # Wrong length - - def test_slicing(self): - """Test slicing operations.""" - v = Vector.from_shape(shape=(4, 3), fields=["field0", "field1", "field2"]) - - # Set data - for i in range(4): - for j in range(3): - v[i, j] = np.array( - [[float(i * 3 + j), float(i * 3 + j + 1), float(i * 3 + j + 2)]] - ) - - # Test slicing - sliced = v[1:3, 1] - assert isinstance(sliced, Vector) - assert sliced.shape == (2, 1) - - # Compare arrays directly - expected1 = np.array([[4.0, 5.0, 6.0]]) - expected2 = np.array([[7.0, 8.0, 9.0]]) - np.testing.assert_array_equal(sliced.get_data(0, 0), expected1) # type: ignore - np.testing.assert_array_equal(sliced.get_data(1, 0), expected2) # type: ignore - - # Test field access on sliced vector - field_sliced = sliced["field1"] - np.testing.assert_array_equal(field_sliced.flatten(), np.array([5.0, 8.0])) # type: ignore - - # Test copying slices of vectors - v[2:4, 1] = v[1:3, 1] - - # Test copying slices of vectors with fancy indexing - v[[0, 1], 1] = v[[2, 3], 0] - - def test_field_management(self): - """Test adding and removing fields.""" - v = Vector.from_shape(shape=(2, 3), fields=["field0", "field1", "field2"]) - - # Set initial data - v[0, 0] = np.array([[1.0, 2.0, 3.0]]) - - # Test adding fields - v.add_fields(["field3", "field4"]) - assert v.num_fields == 5 - assert v.fields == ["field0", "field1", "field2", "field3", "field4"] - assert v.units == ["none", "none", "none", "none", "none"] - - # Check that new fields are initialized to zeros - np.testing.assert_array_equal(v.get_data(0, 0)[:, 3:5], np.array([[0.0, 0.0]])) # type: ignore - - # Test removing fields - v.remove_fields(["field1", "field3"]) - assert v.num_fields == 3 - assert v.fields == ["field0", "field2", "field4"] - assert v.units == ["none", "none", "none"] - - # Check that data is preserved for remaining fields - np.testing.assert_array_equal(v.get_data(0, 0)[:, 0], np.array([1.0])) # type: ignore - np.testing.assert_array_equal(v.get_data(0, 0)[:, 1], np.array([3.0])) # type: ignore - np.testing.assert_array_equal(v.get_data(0, 0)[:, 2], np.array([0.0])) # type: ignore - - # Test error cases - with pytest.raises(ValueError): - v.add_fields(["field0"]) # Duplicate field - - v.remove_fields(["nonexistent_field"]) # Should just print a warning - - def test_copy(self): - """Test deep copying.""" - v = Vector.from_shape(shape=(2, 3), fields=["field0", "field1", "field2"]) - v[0, 0] = np.array([[1.0, 2.0, 3.0]]) - - # Create a copy - v_copy = v.copy() - - # Check that it's a deep copy - assert v_copy is not v - assert v_copy.shape == v.shape - assert v_copy.fields == v.fields - assert v_copy.units == v.units - np.testing.assert_array_equal(v_copy.get_data(0, 0), v.get_data(0, 0)) # type: ignore - - # Modify the copy and check that the original is unchanged - v_copy[0, 0] = np.array([[4.0, 5.0, 6.0]]) - np.testing.assert_array_equal(v.get_data(0, 0), np.array([[1.0, 2.0, 3.0]])) # type: ignore - - def test_flatten(self): - """Test flattening the entire vector.""" - v = Vector.from_shape(shape=(2, 3), fields=["field0", "field1", "field2"]) - - # Set data - v[0, 0] = np.array([[1.0, 2.0, 3.0]]) - v[0, 1] = np.array([[4.0, 5.0, 6.0]]) - v[0, 2] = np.array([[7.0, 8.0, 9.0]]) - v[1, 0] = np.array([[10.0, 11.0, 12.0]]) - v[1, 1] = np.array([[13.0, 14.0, 15.0]]) - v[1, 2] = np.array([[16.0, 17.0, 18.0]]) - - # Flatten the vector - flattened = v.flatten() - - # Check the flattened array - expected = np.array( - [ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - [7.0, 8.0, 9.0], - [10.0, 11.0, 12.0], - [13.0, 14.0, 15.0], - [16.0, 17.0, 18.0], - ] + v.select_fields("kx")[...] += np.arange(6) + np.testing.assert_array_equal( + v.select_fields("kx").flatten(), + np.array([[20.0], [31.0], [42.0], [53.0], [64.0], [75.0]]), ) - np.testing.assert_array_equal(flattened, expected) # type: ignore - def test_from_data(self): - """Test creating a Vector from ragged lists or numpy arrays.""" - # Create test data - data = [ - np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), - np.array([[7.0, 8.0, 9.0]]), - np.array([[10.0, 11.0, 12.0], [13.0, 14.0, 15.0], [16.0, 17.0, 18.0]]), - ] + summed = v.select_fields("intensity") + v.select_fields("ky") + np.testing.assert_array_equal( + summed.flatten(), + np.array([[101.0], [202.0], [303.0], [404.0], [505.0], [606.0]]), + ) - # Test with explicit fields - v1 = Vector.from_data( - data=data, - fields=["field0", "field1", "field2"], - name="test_vector", - units=["unit0", "unit1", "unit2"], + def test_power_operations(self): + v = make_line_vector() + + squared = v.select_fields("intensity") ** 2 + np.testing.assert_array_equal( + squared.flatten(), + np.array([[1.0], [4.0], [9.0], [16.0], [25.0], [36.0]]), ) - # Check properties - assert v1.shape == (3,) - assert v1.num_fields == 3 - assert v1.fields == ["field0", "field1", "field2"] - assert v1.units == ["unit0", "unit1", "unit2"] - assert v1.name == "test_vector" + intensity = v.select_fields("intensity") + intensity **= 2 + np.testing.assert_array_equal( + intensity.flatten(), + np.array([[1.0], [4.0], [9.0], [16.0], [25.0], [36.0]]), + ) - # Check data - np.testing.assert_array_equal(v1.get_data(0), data[0]) # type: ignore - np.testing.assert_array_equal(v1.get_data(1), data[1]) # type: ignore - np.testing.assert_array_equal(v1.get_data(2), data[2]) # type: ignore + reverse = 2 ** v.select_fields("intensity") + np.testing.assert_array_equal( + reverse.flatten(), + np.array([[2.0], [16.0], [512.0], [65536.0], [33554432.0], [68719476736.0]]), + ) - # Test with inferred fields - v2 = Vector.from_data(data=data, num_fields=3) + def test_unary_mod_and_floor_division_operations(self): + v = make_line_vector() - # Check properties - assert v2.shape == (3,) - assert v2.num_fields == 3 - assert v2.fields == ["field_0", "field_1", "field_2"] - assert v2.units == ["none", "none", "none"] + negative = -v.select_fields("intensity") + np.testing.assert_array_equal( + negative.flatten(), + np.array([[-1.0], [-2.0], [-3.0], [-4.0], [-5.0], [-6.0]]), + ) - # Check data - np.testing.assert_array_equal(v2.get_data(0), data[0]) # type: ignore - np.testing.assert_array_equal(v2.get_data(1), data[1]) # type: ignore - np.testing.assert_array_equal(v2.get_data(2), data[2]) # type: ignore + absolute = abs(negative) + np.testing.assert_array_equal( + absolute.flatten(), + np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]), + ) - # Test error cases - with pytest.raises(TypeError, match="Data must be a list"): - Vector.from_data(data=np.array([1, 2, 3])) # type: ignore + floored = v.select_fields("ky") // 150 + np.testing.assert_array_equal( + floored.flatten(), + np.array([[0.0], [1.0], [2.0], [2.0], [3.0], [4.0]]), + ) - with pytest.raises(ValueError, match="does not match length of fields"): - Vector.from_data( - data=data, - fields=["field0", "field1"], # Wrong number of fields - ) + modded = v.select_fields("ky") % 150 + np.testing.assert_array_equal( + modded.flatten(), + np.array([[100.0], [50.0], [0.0], [100.0], [50.0], [0.0]]), + ) - with pytest.raises(ValueError, match="Duplicate field names"): - Vector.from_data( - data=data, - fields=["field0", "field0", "field2"], # Duplicate field names - ) - - def test_fancy_indexing(self): - """Test fancy indexing with __getitem__ and __setitem__.""" - v = Vector.from_shape(shape=(3, 2), fields=["field0", "field1", "field2"]) - - # Set initial data - v[0, 0] = np.array([[1.0, 2.0, 3.0]]) - v[0, 1] = np.array([[4.0, 5.0, 6.0]]) - v[1, 0] = np.array([[7.0, 8.0, 9.0]]) - v[1, 1] = np.array([[10.0, 11.0, 12.0]]) - v[2, 0] = np.array([[13.0, 14.0, 15.0]]) - v[2, 1] = np.array([[16.0, 17.0, 18.0]]) - - # Test list indexing with __getitem__ - result = v[[0, 1], 0] - assert isinstance(result, Vector) - assert result.shape == (2, 1) - np.testing.assert_array_equal(result.get_data(0, 0), np.array([[1.0, 2.0, 3.0]])) - np.testing.assert_array_equal(result.get_data(1, 0), np.array([[7.0, 8.0, 9.0]])) - - # Test numpy array indexing with __getitem__ - result = v[np.array([1, 2]), 1] # type: ignore - assert isinstance(result, Vector) - assert result.shape == (2, 1) - np.testing.assert_array_equal(result.get_data(0, 0), np.array([[10.0, 11.0, 12.0]])) - np.testing.assert_array_equal(result.get_data(1, 0), np.array([[16.0, 17.0, 18.0]])) - - # Test fancy indexing with __setitem__ - new_data = [np.array([[20.0, 21.0, 22.0]]), np.array([[23.0, 24.0, 25.0]])] - v[[0, 2], 1] = new_data - np.testing.assert_array_equal(v.get_data(0, 1), new_data[0]) - np.testing.assert_array_equal(v.get_data(2, 1), new_data[1]) - - # Test numpy array fancy indexing with __setitem__ - new_data = [np.array([[26.0, 27.0, 28.0]]), np.array([[29.0, 30.0, 31.0]])] - v[np.array([1, 2]), 0] = new_data # type: ignore - np.testing.assert_array_equal(v.get_data(1, 0), new_data[0]) - np.testing.assert_array_equal(v.get_data(2, 0), new_data[1]) - - # Test error cases - with pytest.raises(IndexError): - v[[3, 4], 0] # Index out of bounds + ky = v.select_fields("ky") + ky //= 150 + np.testing.assert_array_equal( + ky.flatten(), + np.array([[0.0], [1.0], [2.0], [2.0], [3.0], [4.0]]), + ) - with pytest.raises(IndexError): - v[[0, 1], 2] # Index out of bounds + intensity = v.select_fields("intensity") + intensity %= 2 + np.testing.assert_array_equal( + intensity.flatten(), + np.array([[1.0], [0.0], [1.0], [0.0], [1.0], [0.0]]), + ) - with pytest.raises(ValueError): - v[[0, 1], 0] = [np.array([[1.0]])] # Wrong number of arrays + def test_numpy_ufunc_support(self): + v = make_line_vector() - with pytest.raises(ValueError): - v[[0, 1], 0] = [ - np.array([[1.0]]), - np.array([[2.0]]), - ] # Wrong number of fields - - def test_get_data_methods(self): - """Test get_data method with various indexing scenarios.""" - v = Vector.from_shape(shape=(3, 2), fields=["field0", "field1", "field2"]) - - # Set some test data - v[0, 0] = np.array([[1.0, 2.0, 3.0]]) - v[0, 1] = np.array([[4.0, 5.0, 6.0]]) - v[1, 0] = np.array([[7.0, 8.0, 9.0]]) - v[1, 1] = np.array([[10.0, 11.0, 12.0]]) - v[2, 0] = np.array([[13.0, 14.0, 15.0]]) - v[2, 1] = np.array([[16.0, 17.0, 18.0]]) - - # Test single integer indexing - result = v.get_data(0, 0) - np.testing.assert_array_equal(result, np.array([[1.0, 2.0, 3.0]])) - - # Test list indexing - result = v.get_data([0, 1], 0) - np.testing.assert_array_equal(result[0], np.array([[1.0, 2.0, 3.0]])) - np.testing.assert_array_equal(result[1], np.array([[7.0, 8.0, 9.0]])) - - # Test numpy array indexing - result = v.get_data(np.array([1, 2]), 1) - np.testing.assert_array_equal(result[0], np.array([[10.0, 11.0, 12.0]])) - np.testing.assert_array_equal(result[1], np.array([[16.0, 17.0, 18.0]])) - - # Test slice indexing - result = v.get_data(slice(1, 3), 0) - np.testing.assert_array_equal(result[0], np.array([[7.0, 8.0, 9.0]])) - np.testing.assert_array_equal(result[1], np.array([[13.0, 14.0, 15.0]])) - - # Test error cases - with pytest.raises(ValueError, match="Expected 2 indices"): - v.get_data(0) # Too few indices - - with pytest.raises(ValueError, match="Expected 2 indices"): - v.get_data(0, 0, 0) # Too many indices + sine = np.sin(v.select_fields("kx")) + np.testing.assert_allclose( + sine.flatten(), + np.sin(v.select_fields("kx").flatten()), + ) - with pytest.raises(IndexError): - v.get_data(3, 0) # Index out of bounds + maximum = np.maximum(v.select_fields("intensity"), 3.0) # type: ignore[arg-type] + np.testing.assert_array_equal( + maximum.flatten(), + np.array([[3.0], [3.0], [3.0], [4.0], [5.0], [6.0]]), + ) - with pytest.raises(IndexError): - v.get_data([3, 4], 0) # List index out of bounds - - def test_set_data_methods(self): - """Test set_data method with various indexing scenarios.""" - v = Vector.from_shape(shape=(3, 2), fields=["field0", "field1", "field2"]) - - # Test single integer indexing - data1 = np.array([[1.0, 2.0, 3.0]]) - v.set_data(data1, 0, 0) - np.testing.assert_array_equal(v.get_data(0, 0), data1) - - # Test list indexing - data2 = [np.array([[4.0, 5.0, 6.0]]), np.array([[7.0, 8.0, 9.0]])] - v.set_data(data2, [0, 1], 1) - np.testing.assert_array_equal(v.get_data(0, 1), data2[0]) - np.testing.assert_array_equal(v.get_data(1, 1), data2[1]) - - # Test numpy array indexing - data3 = [np.array([[10.0, 11.0, 12.0]]), np.array([[13.0, 14.0, 15.0]])] - v.set_data(data3, np.array([1, 2]), 0) - np.testing.assert_array_equal(v.get_data(1, 0), data3[0]) - np.testing.assert_array_equal(v.get_data(2, 0), data3[1]) - - # Test slice indexing - data4 = [np.array([[16.0, 17.0, 18.0]]), np.array([[19.0, 20.0, 21.0]])] - v.set_data(data4, slice(1, 3), 1) - np.testing.assert_array_equal(v.get_data(1, 1), data4[0]) - np.testing.assert_array_equal(v.get_data(2, 1), data4[1]) - - # Test error cases - with pytest.raises(ValueError, match="Expected 2 indices"): - v.set_data(data1, 0) # Too few indices - - with pytest.raises(ValueError, match="Expected 2 indices"): - v.set_data(data1, 0, 0, 0) # Too many indices + frac, whole = np.modf(v.select_fields("intensity") / 2.0) + np.testing.assert_allclose( + frac.flatten(), + np.array([[0.5], [0.0], [0.5], [0.0], [0.5], [0.0]]), + ) + np.testing.assert_allclose( + whole.flatten(), + np.array([[0.0], [1.0], [1.0], [2.0], [2.0], [3.0]]), + ) - with pytest.raises(IndexError): - v.set_data(data1, 3, 0) # Index out of bounds + def test_field_assignment_from_vector_expression(self): + v = make_line_vector() + scale = 2.5 + + v[:2].select_fields("intensity")[...] = v[2:4].select_fields("intensity") * scale + np.testing.assert_array_equal( + v[:2].select_fields("intensity").flatten(), + np.array([[10.0], [12.5], [15.0]]), + ) + + def test_field_assignment_requires_matching_per_cell_row_counts(self): + v = make_line_vector() + with pytest.raises(ValueError, match="Per-cell row counts must match"): + v[:2].select_fields("intensity")[...] = v[1:3].select_fields("intensity") + + def test_full_cell_assignment_allows_row_count_changes(self): + v = make_line_vector() + + v[1] = v[0] + assert v[1].array.shape == (2, 3) + np.testing.assert_array_equal(v[1].array, v[0].array) + + v[0:2] = v[1:3] + assert v[0].array.shape == (2, 3) + assert v[1].array.shape == (2, 3) + + broadcast_cell = np.array([[9.0, 8.0, 7.0]]) + v[[0, 3]] = broadcast_cell + np.testing.assert_array_equal(v[0].array, broadcast_cell) + np.testing.assert_array_equal(v[3].array, broadcast_cell) + + def test_append_rows_and_compact(self): + v = make_line_vector() + + v.append_rows(1, np.array([[7.0, 70.0, 700.0]])) + np.testing.assert_array_equal( + v[1].array, + np.array([[3.0, 30.0, 300.0], [7.0, 70.0, 700.0]]), + ) + + v[1] = np.array([[8.0, 80.0, 800.0]]) + assert v._state["data"].shape[0] > v.total_rows + + v.compact() + assert v._state["data"].shape[0] == v.total_rows + + with pytest.raises(ValueError, match="exactly one cell"): + v.append_rows(slice(None), np.array([[1.0, 2.0, 3.0]])) + + def test_boolean_indexing_is_axis_wise(self): + v = make_grid_vector() + + rows = np.array([True, False, True]) + cols = np.array([False, True]) + selected = v[rows, cols] + + assert selected.shape == (2, 1) + np.testing.assert_array_equal(selected[0, 0].array, np.array([[1.0, 101.0, 201.0]])) + np.testing.assert_array_equal(selected[1, 0].array, np.array([[21.0, 121.0, 221.0]])) with pytest.raises(IndexError): - v.set_data([data1, data1], [3, 4], 0) # List index out of bounds + _ = v[np.array([[True, False], [False, True]])] - with pytest.raises(TypeError): - v.set_data([1, 2, 3], 0, 0) # Invalid data type # type: ignore + def test_empty_selection_is_valid_and_no_op_for_scalar_math(self): + v = make_grid_vector() + before = v.copy().flatten() - with pytest.raises(ValueError): - v.set_data(np.array([[1.0]]), 0, 0) # Wrong number of fields + empty = v[[], :] + assert empty.shape == (0, 2) + assert empty.flatten().shape == (0, 3) + + empty.select_fields("kx")[...] += 1 + np.testing.assert_array_equal(v.flatten(), before) + + def test_add_fields_defaults_expression_and_multiple_values(self): + v = make_line_vector() + + v.add_fields(("h", "k")) + assert v.fields == ["intensity", "kx", "ky", "h", "k"] + assert np.isnan(v[0].array[:, 3:5]).all() + + v.add_fields("field_out", v.select_fields("kx") + v.select_fields("ky")) + np.testing.assert_array_equal( + v.select_fields("field_out").flatten(), + np.array([[110.0], [220.0], [330.0], [440.0], [550.0], [660.0]]), + ) + + v2 = make_line_vector() + v2.add_fields(("h", "k"), (1.0, np.array([5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))) + np.testing.assert_array_equal(v2.select_fields("h").flatten(), np.ones((6, 1))) + np.testing.assert_array_equal( + v2.select_fields("k").flatten(), + np.array([[5.0], [6.0], [7.0], [8.0], [9.0], [10.0]]), + ) + + with pytest.raises(ValueError, match="all fields are selected"): + v2.select_fields("kx").add_fields("bad") + + def test_rename_fields(self): + v = make_line_vector() + kx_data = v.select_fields("kx").flatten().copy() + + v.rename_fields({"kx": "qx", "ky": "qy"}) + assert v.fields == ["intensity", "qx", "qy"] + np.testing.assert_array_equal(v.select_fields("qx").flatten(), kx_data) + + # Renaming through a field-selected view updates that view's selected names + view = v.select_fields("qx") + assert view.fields == ["qx"] + view.rename_fields({"qx": "px"}) + assert view.fields == ["px"] + assert v.fields == ["intensity", "px", "qy"] + + with pytest.raises(KeyError, match="Unknown field"): + v.rename_fields({"nonexistent": "x"}) + + with pytest.raises(ValueError, match="already exist"): + v.rename_fields({"px": "intensity"}) + + def test_remove_fields_preserves_remaining_data(self): + v = make_line_vector() + v.add_fields("extra", 1.0) + v.remove_fields(("kx", "extra")) + + assert v.fields == ["intensity", "ky"] + np.testing.assert_array_equal( + v[0].array, + np.array([[1.0, 100.0], [2.0, 200.0]]), + ) + + def test_copy_is_deep(self): + v = make_line_vector() + v_copy = v.select_fields(["intensity", "kx"]).copy() + + v_copy[0].array[0, 0] = -1.0 + assert v[0].array[0, 0] == 1.0 + assert v_copy.fields == ["intensity", "kx"] + assert v_copy.shape == (4,) + + def test_from_data_supports_nested_fixed_grid(self): + data = [ + [np.array([[1.0, 2.0]]), np.array([[3.0, 4.0], [5.0, 6.0]])], + [np.array([[7.0, 8.0]]), np.array([[9.0, 10.0]])], + ] + v = Vector.from_data(data=data, fields=["a", "b"], units=["u1", "u2"], name="nested") + + assert v.shape == (2, 2) + assert v.fields == ["a", "b"] + assert v.units == ["u1", "u2"] + assert v.name == "nested" + np.testing.assert_array_equal(v[0, 1].array, np.array([[3.0, 4.0], [5.0, 6.0]])) + + tuple_cells = [ + ([1.0, 2.0], [3.0, 4.0]), + ([5.0, 6.0], [7.0, 8.0], [9.0, 10.0]), + ] + tuple_vector = Vector.from_data(data=tuple_cells, fields=["a", "b"]) + assert tuple_vector.shape == (2,) + np.testing.assert_array_equal(tuple_vector[0].array, np.array([[1.0, 2.0], [3.0, 4.0]])) + np.testing.assert_array_equal( + tuple_vector[1].array, + np.array([[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]), + ) + + tuple_data = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]])) + tuple_outer = Vector.from_data(data=tuple_data, fields=["a", "b"]) + assert tuple_outer.shape == (2,) + + with pytest.raises(TypeError, match="Data must be a list or tuple"): + Vector.from_data(data=np.array([1, 2, 3])) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="same number of fields"): + Vector.from_data(data=[np.array([[1.0, 2.0]]), np.array([[1.0, 2.0, 3.0]])]) + + def test_save_and_load_round_trip(self, tmp_path): + v = make_grid_vector() + v.add_fields("extra", v.select_fields("intensity") + 1.0) + + path = tmp_path / "vector_test.zip" + v.save(path, mode="o", compression_level=4) + + with zipfile.ZipFile(path) as zf: + names = [info.filename for info in zf.infolist()] + assert len(names) < 30 + assert "_state/data/zarr.json" in names + assert all(not name.startswith("_selection_coords/") for name in names) + + loaded = load(path) + assert isinstance(loaded, Vector) + assert loaded.shape == v.shape + assert loaded.fields == v.fields + assert loaded.units == v.units + np.testing.assert_array_equal(loaded[2, 1].array, v[2, 1].array) diff --git a/tests/visualization/test_visualization_utils.py b/tests/visualization/test_visualization_utils.py index c3a1f048..ca114774 100644 --- a/tests/visualization/test_visualization_utils.py +++ b/tests/visualization/test_visualization_utils.py @@ -297,3 +297,11 @@ def test_bilinear_histogram_2d_with_custom_statistic(self): hist = bilinear_histogram_2d(shape, x, y, weight, statistic="mean") assert hist.shape == shape assert np.all(~np.isnan(hist[hist > 0])) # Only check non-zero values for NaN + + def test_bilinear_histogram_2d_accepts_column_vectors(self): + shape = (8, 8) + x = np.array([[1.0], [2.0], [3.0]]) + y = np.array([[1.5], [2.5], [3.5]]) + weight = np.array([[2.0], [3.0], [4.0]]) + hist = bilinear_histogram_2d(shape, x, y, weight) + assert hist.shape == shape