diff --git a/doc/changes/dev/14003.bugfix.rst b/doc/changes/dev/14003.bugfix.rst new file mode 100644 index 00000000000..7f418237247 --- /dev/null +++ b/doc/changes/dev/14003.bugfix.rst @@ -0,0 +1 @@ +Fix :func:`mne.time_frequency.psd_array_welch` (and Welch-method ``compute_psd``) so that good data spans shorter than ``n_per_seg`` no longer raise ``noverlap must be less than nperseg``; such spans are now dropped from the estimate with a warning, by :newcontrib:`Cedric Conday`. diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 2bf4fc9fdd7..9a797c38834 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -52,6 +52,7 @@ .. _Carina Forster: https://github.com/CarinaFo .. _Carlos de la Torre-Ortiz: https://github.com/c-torre .. _Cathy Nangini: https://github.com/KatiRG +.. _Cedric Conday: https://github.com/CedricConday .. _Chetan Gohil: https://github.com/cgohil8 .. _Chris Bailey: https://github.com/cjayb .. _Chris Holdgraf: https://chrisholdgraf.com diff --git a/mne/time_frequency/psd.py b/mne/time_frequency/psd.py index 01d932699a1..643233b2c5e 100644 --- a/mne/time_frequency/psd.py +++ b/mne/time_frequency/psd.py @@ -2,7 +2,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -import warnings from functools import partial import numpy as np @@ -10,7 +9,7 @@ from ..fixes import _reshape_view from ..parallel import parallel_func -from ..utils import _check_option, _ensure_int, logger, verbose, warn +from ..utils import _check_option, _ensure_int, _pl, logger, verbose, warn from ..utils.numerics import _mask_to_onsets_offsets @@ -258,15 +257,29 @@ def psd_array_welch( # Aligned NaNs across channels → treat as bad annotations. good_mask = ~nan_mask_full t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0]) - x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)] - # weights reflect the number of samples used from each span. For spans longer - # than `n_per_seg`, trailing samples may be discarded. For spans shorter than - # `n_per_seg`, the wrapped function (`scipy.signal.spectrogram`) automatically - # reduces `n_per_seg` to match the span length (with a warning). + all_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)] + # Drop good data spans shorter than n_per_seg: a single Welch window does not + # fit them. (Shrinking the window per-span would mix incompatible estimates, + # and passing them to SciPy as-is raises "noverlap must be less than + # nperseg".) Warn so the user can lower n_per_seg to keep them. See #13039. + x_splits = [span for span in all_splits if span.shape[-1] >= n_per_seg] + n_dropped = len(all_splits) - len(x_splits) + if n_dropped: + warn( + f"{n_dropped} good data span{_pl(n_dropped)} shorter than n_per_seg " + f"({n_per_seg}) {'was' if n_dropped == 1 else 'were'} excluded from " + "the PSD estimate; reduce n_per_seg (or n_fft) to include them." + ) + if not x_splits: + raise ValueError( + f"All good data spans are shorter than n_per_seg ({n_per_seg}); no " + "data is left to compute the PSD. Reduce n_per_seg (or n_fft)." + ) + # weights reflect the number of samples used from each (kept) span; trailing + # samples beyond the last full window are discarded. step = n_per_seg - n_overlap - span_lengths = [span.shape[-1] for span in x_splits] weights = [ - w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths + w - ((w - n_overlap) % step) for w in (s.shape[-1] for s in x_splits) ] agg_func = partial(np.average, weights=weights) if n_jobs > 1: @@ -274,22 +287,7 @@ def psd_array_welch( f"Data split into {len(x_splits)} (probably unequal) chunks due to " '"bad_*" annotations. Parallelization may be sub-optimal.' ) - if (np.array(span_lengths) < n_per_seg).any(): - logger.info( - "At least one good data span is shorter than n_per_seg, and will be " - "analyzed with a shorter window than the rest of the file." - ) - - def func(*args, **kwargs): - # swallow SciPy warnings caused by short good data spans - with warnings.catch_warnings(): - warnings.filterwarnings( - action="ignore", - module="scipy", - category=UserWarning, - message=r"nperseg = \d+ is greater than input length", - ) - return _func(*args, **kwargs) + func = _func else: # Either no NaNs, or NaNs are not aligned across channels. diff --git a/mne/time_frequency/tests/test_psd.py b/mne/time_frequency/tests/test_psd.py index 9718f80b153..743e2fd8b0d 100644 --- a/mne/time_frequency/tests/test_psd.py +++ b/mne/time_frequency/tests/test_psd.py @@ -58,6 +58,34 @@ def test_bad_annot_handling(): np.testing.assert_allclose(got[0], want[0], rtol=1e-15, atol=0) +def test_psd_welch_short_span_dropped(): + """Good spans shorter than n_per_seg are dropped with a warning (gh-13039).""" + n_fft = 256 + n_overlap = n_fft // 2 # 128 + n_chan = 2 + rng = np.random.default_rng(0) + # A short good span (100 samples < n_per_seg), then a bad-annotation + # (aligned NaN), then a long span. The short span cannot hold a single + # Welch window; it is now dropped with a warning rather than raising from + # SciPy ("noverlap must be less than nperseg"). + short = rng.standard_normal((n_chan, 100)) + long = rng.standard_normal((n_chan, 5 * n_fft)) + x = np.concatenate((short, np.full((n_chan, 1), np.nan), long), axis=-1) + with pytest.warns(RuntimeWarning, match="shorter than n_per_seg"): + psds, freqs = psd_array_welch(x, sfreq=100, n_fft=n_fft, n_overlap=n_overlap) + assert psds.shape == (n_chan, len(freqs)) + assert np.all(np.isfinite(psds)) + + # If *every* good span is too short, there is nothing left to analyze. Use + # three short spans so the total length still exceeds n_fft (otherwise the + # earlier n_fft > n_times guard fires first). + nan_col = np.full((n_chan, 1), np.nan) + x_all_short = np.concatenate((short, nan_col, short, nan_col, short), axis=-1) + with pytest.raises(ValueError, match="All good data spans are shorter"): + with pytest.warns(RuntimeWarning, match="shorter than n_per_seg"): + psd_array_welch(x_all_short, sfreq=100, n_fft=n_fft, n_overlap=n_overlap) + + def _make_psd_data(): """Make noise data with sinusoids in 2 out of 7 channels.""" rng = np.random.default_rng(0)