diff --git a/doc/changes/dev/13994.bugfix.rst b/doc/changes/dev/13994.bugfix.rst new file mode 100644 index 00000000000..ca1e4537be7 --- /dev/null +++ b/doc/changes/dev/13994.bugfix.rst @@ -0,0 +1 @@ +Fix crash in :func:`mne.preprocessing.annotate_muscle_zscore` when ``n_jobs='cuda'`` by adding ``n_jobs_hilbert`` for the Hilbert transform, by :newcontrib:`Krishnaveni Parvataneni`. diff --git a/mne/preprocessing/artifact_detection.py b/mne/preprocessing/artifact_detection.py index 8674d6e22b3..c87934c25f8 100644 --- a/mne/preprocessing/artifact_detection.py +++ b/mne/preprocessing/artifact_detection.py @@ -43,6 +43,7 @@ def annotate_muscle_zscore( min_length_good=0.1, filter_freq=(110, 140), n_jobs=None, + n_jobs_hilbert=None, verbose=None, ): """Create annotations for segments that likely contain muscle artifacts. @@ -76,6 +77,10 @@ def annotate_muscle_zscore( The lower and upper frequencies of the band-pass filter. Default is ``(110, 140)``. %(n_jobs)s + n_jobs_hilbert : int | None + Number of jobs for the Hilbert transform. Cannot be ``'cuda'``. + Defaults to ``1`` when ``None``. Use when ``n_jobs='cuda'`` to keep + filtering on GPU while Hilbert runs on CPU. %(verbose)s Returns @@ -116,7 +121,9 @@ def annotate_muscle_zscore( pad="reflect_limited", n_jobs=n_jobs, ) - raw_copy.apply_hilbert(envelope=True, n_jobs=n_jobs) + raw_copy.apply_hilbert( + envelope=True, n_jobs=n_jobs_hilbert if n_jobs_hilbert is not None else 1 + ) data = raw_copy.get_data(reject_by_annotation="NaN") nan_mask = ~np.isnan(data[0]) diff --git a/mne/preprocessing/tests/test_artifact_detection.py b/mne/preprocessing/tests/test_artifact_detection.py index ccd8893ba11..dea4b9eccb4 100644 --- a/mne/preprocessing/tests/test_artifact_detection.py +++ b/mne/preprocessing/tests/test_artifact_detection.py @@ -190,6 +190,34 @@ def test_muscle_annotation_without_meeg_data(meas_date): annotate_muscle_zscore(raw, threshold=10) +@testing.requires_testing_data +def test_muscle_annotation_n_jobs_cuda(): + """Test annotate_muscle_zscore with separate n_jobs for filter and Hilbert.""" + raw = read_raw_fif(raw_fname, allow_maxshield="yes").load_data() + raw.notch_filter([50, 110, 150]) + + annot_cuda, scores_cuda = annotate_muscle_zscore( + raw, ch_type="mag", threshold=10, n_jobs="cuda", n_jobs_hilbert=1 + ) + assert annot_cuda.duration.size == 2 + assert scores_cuda.shape == (raw.n_times,) + + annot_default, scores_default = annotate_muscle_zscore( + raw, ch_type="mag", threshold=10 + ) + annot_int, scores_int = annotate_muscle_zscore( + raw, ch_type="mag", threshold=10, n_jobs=1, n_jobs_hilbert=1 + ) + assert annot_int.duration.size == annot_default.duration.size + assert_array_equal(scores_int, scores_default) + + annot_jobs, scores_jobs = annotate_muscle_zscore( + raw, ch_type="mag", threshold=10, n_jobs=4, n_jobs_hilbert=2 + ) + assert annot_jobs.duration.size == annot_default.duration.size + assert_array_equal(scores_jobs, scores_default) + + @pytest.mark.parametrize("meas_date", (None, "orig")) @testing.requires_testing_data def test_annotate_breaks(meas_date):