Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/14003.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix :func:`mne.time_frequency.psd_array_welch` (and Welch-method ``compute_psd``) so that good data spans shorter than ``n_per_seg`` no longer raise ``noverlap must be less than nperseg``; such spans are now dropped from the estimate with a warning, by :newcontrib:`Cedric Conday`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
.. _Carina Forster: https://github.com/CarinaFo
.. _Carlos de la Torre-Ortiz: https://github.com/c-torre
.. _Cathy Nangini: https://github.com/KatiRG
.. _Cedric Conday: https://github.com/CedricConday
.. _Chetan Gohil: https://github.com/cgohil8
.. _Chris Bailey: https://github.com/cjayb
.. _Chris Holdgraf: https://chrisholdgraf.com
Expand Down
48 changes: 23 additions & 25 deletions mne/time_frequency/psd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import warnings
from functools import partial

import numpy as np
from scipy.signal import spectrogram

from ..fixes import _reshape_view
from ..parallel import parallel_func
from ..utils import _check_option, _ensure_int, logger, verbose, warn
from ..utils import _check_option, _ensure_int, _pl, logger, verbose, warn
from ..utils.numerics import _mask_to_onsets_offsets


Expand Down Expand Up @@ -258,38 +257,37 @@ def psd_array_welch(
# Aligned NaNs across channels → treat as bad annotations.
good_mask = ~nan_mask_full
t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
# weights reflect the number of samples used from each span. For spans longer
# than `n_per_seg`, trailing samples may be discarded. For spans shorter than
# `n_per_seg`, the wrapped function (`scipy.signal.spectrogram`) automatically
# reduces `n_per_seg` to match the span length (with a warning).
all_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
# Drop good data spans shorter than n_per_seg: a single Welch window does not
# fit them. (Shrinking the window per-span would mix incompatible estimates,
# and passing them to SciPy as-is raises "noverlap must be less than
# nperseg".) Warn so the user can lower n_per_seg to keep them. See #13039.
x_splits = [span for span in all_splits if span.shape[-1] >= n_per_seg]
n_dropped = len(all_splits) - len(x_splits)
if n_dropped:
warn(
f"{n_dropped} good data span{_pl(n_dropped)} shorter than n_per_seg "
f"({n_per_seg}) {'was' if n_dropped == 1 else 'were'} excluded from "
"the PSD estimate; reduce n_per_seg (or n_fft) to include them."
)
if not x_splits:
raise ValueError(
f"All good data spans are shorter than n_per_seg ({n_per_seg}); no "
"data is left to compute the PSD. Reduce n_per_seg (or n_fft)."
)
# weights reflect the number of samples used from each (kept) span; trailing
# samples beyond the last full window are discarded.
step = n_per_seg - n_overlap
span_lengths = [span.shape[-1] for span in x_splits]
weights = [
w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths
w - ((w - n_overlap) % step) for w in (s.shape[-1] for s in x_splits)
]
agg_func = partial(np.average, weights=weights)
if n_jobs > 1:
logger.info(
f"Data split into {len(x_splits)} (probably unequal) chunks due to "
'"bad_*" annotations. Parallelization may be sub-optimal.'
)
if (np.array(span_lengths) < n_per_seg).any():
logger.info(
"At least one good data span is shorter than n_per_seg, and will be "
"analyzed with a shorter window than the rest of the file."
)

def func(*args, **kwargs):
# swallow SciPy warnings caused by short good data spans
with warnings.catch_warnings():
warnings.filterwarnings(
action="ignore",
module="scipy",
category=UserWarning,
message=r"nperseg = \d+ is greater than input length",
)
return _func(*args, **kwargs)
func = _func

else:
# Either no NaNs, or NaNs are not aligned across channels.
Expand Down
28 changes: 28 additions & 0 deletions mne/time_frequency/tests/test_psd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,34 @@ def test_bad_annot_handling():
np.testing.assert_allclose(got[0], want[0], rtol=1e-15, atol=0)


def test_psd_welch_short_span_dropped():
"""Good spans shorter than n_per_seg are dropped with a warning (gh-13039)."""
n_fft = 256
n_overlap = n_fft // 2 # 128
n_chan = 2
rng = np.random.default_rng(0)
# A short good span (100 samples < n_per_seg), then a bad-annotation
# (aligned NaN), then a long span. The short span cannot hold a single
# Welch window; it is now dropped with a warning rather than raising from
# SciPy ("noverlap must be less than nperseg").
short = rng.standard_normal((n_chan, 100))
long = rng.standard_normal((n_chan, 5 * n_fft))
x = np.concatenate((short, np.full((n_chan, 1), np.nan), long), axis=-1)
with pytest.warns(RuntimeWarning, match="shorter than n_per_seg"):
psds, freqs = psd_array_welch(x, sfreq=100, n_fft=n_fft, n_overlap=n_overlap)
assert psds.shape == (n_chan, len(freqs))
assert np.all(np.isfinite(psds))

# If *every* good span is too short, there is nothing left to analyze. Use
# three short spans so the total length still exceeds n_fft (otherwise the
# earlier n_fft > n_times guard fires first).
nan_col = np.full((n_chan, 1), np.nan)
x_all_short = np.concatenate((short, nan_col, short, nan_col, short), axis=-1)
with pytest.raises(ValueError, match="All good data spans are shorter"):
with pytest.warns(RuntimeWarning, match="shorter than n_per_seg"):
psd_array_welch(x_all_short, sfreq=100, n_fft=n_fft, n_overlap=n_overlap)


def _make_psd_data():
"""Make noise data with sinusoids in 2 out of 7 channels."""
rng = np.random.default_rng(0)
Expand Down
Loading