Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13994.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix crash in :func:`mne.preprocessing.annotate_muscle_zscore` when ``n_jobs='cuda'`` by adding ``n_jobs_hilbert`` for the Hilbert transform, by :newcontrib:`Krishnaveni Parvataneni`.
9 changes: 8 additions & 1 deletion mne/preprocessing/artifact_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def annotate_muscle_zscore(
min_length_good=0.1,
filter_freq=(110, 140),
n_jobs=None,
n_jobs_hilbert=None,
verbose=None,
):
"""Create annotations for segments that likely contain muscle artifacts.
Expand Down Expand Up @@ -76,6 +77,10 @@ def annotate_muscle_zscore(
The lower and upper frequencies of the band-pass filter.
Default is ``(110, 140)``.
%(n_jobs)s
n_jobs_hilbert : int | None
Number of jobs for the Hilbert transform. Cannot be ``'cuda'``.
Defaults to ``1`` when ``None``. Use when ``n_jobs='cuda'`` to keep
filtering on GPU while Hilbert runs on CPU.
%(verbose)s

Returns
Expand Down Expand Up @@ -116,7 +121,9 @@ def annotate_muscle_zscore(
pad="reflect_limited",
n_jobs=n_jobs,
)
raw_copy.apply_hilbert(envelope=True, n_jobs=n_jobs)
raw_copy.apply_hilbert(
envelope=True, n_jobs=n_jobs_hilbert if n_jobs_hilbert is not None else 1
)
Comment on lines +124 to +126

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm... this isn't backward compatible. Previously, if a user provided n_jobs=4 for example, this would use 4 jobs. So a better pattern I think is:

if n_jobs != "cuda" and n_jobs_hilbert is None:
    n_jobs_hilbert = n_jobs

And you say in the docstring that n_jobs_hilbert will default to the value of n_jobs when n_jobs != "cuda".


data = raw_copy.get_data(reject_by_annotation="NaN")
nan_mask = ~np.isnan(data[0])
Expand Down
28 changes: 28 additions & 0 deletions mne/preprocessing/tests/test_artifact_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,34 @@ def test_muscle_annotation_without_meeg_data(meas_date):
annotate_muscle_zscore(raw, threshold=10)


@testing.requires_testing_data
def test_muscle_annotation_n_jobs_cuda():
"""Test annotate_muscle_zscore with separate n_jobs for filter and Hilbert."""
raw = read_raw_fif(raw_fname, allow_maxshield="yes").load_data()
raw.notch_filter([50, 110, 150])

annot_cuda, scores_cuda = annotate_muscle_zscore(
raw, ch_type="mag", threshold=10, n_jobs="cuda", n_jobs_hilbert=1
)
assert annot_cuda.duration.size == 2
assert scores_cuda.shape == (raw.n_times,)

annot_default, scores_default = annotate_muscle_zscore(
raw, ch_type="mag", threshold=10
)
annot_int, scores_int = annotate_muscle_zscore(
raw, ch_type="mag", threshold=10, n_jobs=1, n_jobs_hilbert=1
)
assert annot_int.duration.size == annot_default.duration.size
assert_array_equal(scores_int, scores_default)

annot_jobs, scores_jobs = annotate_muscle_zscore(
raw, ch_type="mag", threshold=10, n_jobs=4, n_jobs_hilbert=2
)
assert annot_jobs.duration.size == annot_default.duration.size
assert_array_equal(scores_jobs, scores_default)


@pytest.mark.parametrize("meas_date", (None, "orig"))
@testing.requires_testing_data
def test_annotate_breaks(meas_date):
Expand Down