From ea4ca2c8f1958ada2cb203715b45e3480f037c2a Mon Sep 17 00:00:00 2001 From: payam Date: Sun, 28 Jun 2026 11:26:30 +0200 Subject: [PATCH 01/12] FIX: use direct dict access for hemi_data['key'] --- mne/viz/_brain/_brain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index b66bdf5c436..e8255750a72 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -3442,7 +3442,7 @@ def _update_colormap_range(self, fmin=None, fmid=None, fmax=None, alpha=None): if hemi in self.layered_meshes: mesh = self.layered_meshes[hemi] mesh.update_overlay( - name=hemi_data.get("key", "data"), + name=hemi_data["key"], colormap=self._data["ctable"], opacity=alpha, rng=rng, From 8c012fc4cb5ce2f236a555696f23fe3dea298193 Mon Sep 17 00:00:00 2001 From: payam Date: Sun, 28 Jun 2026 11:27:04 +0200 Subject: [PATCH 02/12] TST: add pytest.raises test for LayeredMesh.update_overlay shape check from previous PR --- mne/viz/_brain/tests/test_brain.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index d50495b14ab..b60f1d1fa73 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -397,6 +397,8 @@ def __init__(self): assert isinstance(lm, LayeredMesh) lm.update_overlay(name="data", rng=[fmin, fmax]) lm.update() + with pytest.raises(ValueError, match="must have shape"): + lm.update_overlay(name="data", scalars=np.ones(1)) brain.remove_data() assert "data" not in brain._actors assert "time_change" not in ui_events._get_event_channel(brain) From 2cf57a45232fdce2bd0de6aa438454e176f4f088 Mon Sep 17 00:00:00 2001 From: payam Date: Wed, 1 Jul 2026 08:36:33 +0200 Subject: [PATCH 03/12] ENH: Support overlays at the same time in Brain.add_data --- mne/viz/_brain/_brain.py | 372 ++++++++++++++++++++++----------------- 1 file changed, 207 insertions(+), 165 deletions(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index e8255750a72..67c86648990 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -378,8 +378,10 @@ def __init__( # for now only one time label can be added # since it is the same for all figures self._time_label_added = False - # array of data used by TimeViewer + # array of data used by TimeViewer; keyed by overlay name self._data = {} + self._active_data_key = None + self._foci_data = {} self.geo = {} self.set_time_interpolation("nearest") @@ -713,7 +715,7 @@ def set_playback_speed(self, speed): publish(self, PlaybackSpeed(speed=speed)) def _configure_time_label(self): - self.time_actor = self._data.get("time_actor") + self.time_actor = self._active_data.get("time_actor") if self.time_actor is not None: self.time_actor.SetPosition(0.5, 0.03) self.time_actor.GetTextProperty().SetJustificationToCentered() @@ -727,7 +729,7 @@ def _configure_scalar_bar(self): self._scalar_bar.SetPosition(0.02, 0.2) def _configure_dock_playback_widget(self, name): - len_time = len(self._data["time"]) - 1 + len_time = len(self._active_data["time"]) - 1 # Time widget if len_time < 1: @@ -744,7 +746,7 @@ def current_time_func(): self._renderer._enable_time_interaction( self, current_time_func, - self._data["time"], + self._active_data["time"], self.default_playback_speed_value, self.default_playback_speed_range, ) @@ -752,7 +754,7 @@ def current_time_func(): # Time label current_time = self._current_time assert current_time is not None # should never be the case, float - time_label = self._data["time_label"] + time_label = self._active_data["time_label"] if callable(time_label): current_time = time_label(current_time) else: @@ -821,7 +823,7 @@ def set_orientation(value, orientation_data=orientation_data): def _configure_dock_colormap_widget(self, name): fmax, fscale, fscale_power = _get_range(self) rng = [0, fmax * fscale] - self._data["fscale"] = fscale + self._active_data["fscale"] = fscale layout = self._renderer._dock_add_group_box(name) text = "min / mid / max" @@ -836,14 +838,14 @@ def _configure_dock_colormap_widget(self, name): @_auto_weakref def update_single_lut_value(value, key): # Called by the sliders and spin boxes. - self.update_lut(**{key: value / self._data["fscale"]}) + self.update_lut(**{key: value / self._active_data["fscale"]}) keys = ("fmin", "fmid", "fmax") for key in keys: hlayout = self._renderer._dock_add_layout(vertical=False) self.widgets[key] = self._renderer._dock_add_slider( name=None, - value=self._data[key] * self._data["fscale"], + value=self._active_data[key] * self._active_data["fscale"], rng=rng, callback=partial(update_single_lut_value, key=key), double=True, @@ -851,7 +853,7 @@ def update_single_lut_value(value, key): ) self.widgets[f"entry_{key}"] = self._renderer._dock_add_spin_box( name=None, - value=self._data[key] * self._data["fscale"], + value=self._active_data[key] * self._active_data["fscale"], callback=partial(update_single_lut_value, key=key), rng=rng, layout=hlayout, @@ -900,8 +902,8 @@ def _configure_dock_trace_widget(self, name): return # do not show trace mode for volumes if ( - self._data.get("src", None) is not None - and self._data["src"].kind == "volume" + self._active_data.get("src", None) is not None + and self._active_data["src"].kind == "volume" ): self._configure_vertex_time_course() return @@ -949,9 +951,9 @@ def _set_label_mode(mode): cands = _read_annot_cands(dir_name, raise_error=False) cands = cands + ["None"] self.annot = cands[0] - stc = self._data["stc"] + stc = self._active_data["stc"] modes = _get_allowed_label_modes(stc) - if self._data["src"] is None: + if self._active_data["src"] is None: modes = [ m for m in modes if m not in self.default_label_extract_modes["src"] ] @@ -985,7 +987,7 @@ def _configure_dock(self): # Smoothing widget self.widgets["smoothing"] = self._renderer._dock_add_spin_box( name="Smoothing", - value=self._data["smoothing_steps"], + value=self._active_data["smoothing_steps"], rng=self.default_smoothing_range, callback=self.set_data_smoothing, double=False, @@ -1001,7 +1003,7 @@ def _configure_mplcanvas(self): show_traces=self.show_traces, separate_canvas=self.separate_canvas, ) - xlim = [np.min(self._data["time"]), np.max(self._data["time"])] + xlim = [np.min(self._active_data["time"]), np.max(self._active_data["time"])] with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) self.mpl_canvas.axes.set(xlim=xlim) @@ -1028,7 +1030,7 @@ def _configure_vertex_time_course(self): del y (self.rms,) = self.mpl_canvas.axes.plot( - self._data["time"], + self._active_data["time"], rms, lw=3, label="RMS", @@ -1046,7 +1048,7 @@ def _configure_vertex_time_course(self): act_data = self.act_data_smooth.get(hemi, [None])[0] if act_data is None: continue - hemi_data = self._data[hemi] + hemi_data = self._active_data[hemi] vertices = hemi_data["vertices"] # simulate a picked renderer @@ -1055,9 +1057,9 @@ def _configure_vertex_time_course(self): self._picked_renderer = self._renderer._all_renderers[idx] # initialize the default point - if self._data["initial_time"] is not None: + if self._active_data["initial_time"] is not None: # pick at that time - use_data = act_data[:, [np.round(self._data["time_idx"]).astype(int)]] + use_data = act_data[:, [np.round(self._active_data["time_idx"]).astype(int)]] else: use_data = act_data ind = np.unravel_index( @@ -1069,7 +1071,7 @@ def _configure_vertex_time_course(self): def _configure_picking(self): # get data for each hemi for idx, hemi in enumerate(["vol", "lh", "rh"]): - hemi_data = self._data.get(hemi) + hemi_data = self._active_data.get(hemi) if hemi_data is not None: act_data = hemi_data["array"] if act_data.ndim == 3: @@ -1247,9 +1249,9 @@ def _on_pick(self, vtk_picker, event): # camera-to-click array, which fortunately we can get "just" # by inspecting the points that are sufficiently close to the # ray. - grid = self._data[hemi]["grid"] - vertices = self._data[hemi]["vertices"] - coords = self._data[hemi]["grid_coords"][vertices] + grid = self._active_data[hemi]["grid"] + vertices = self._active_data[hemi]["vertices"] + coords = self._active_data[hemi]["grid_coords"][vertices] scalars = grid.point_data["values"][vertices] spacing = np.array(grid.GetSpacing()) max_dist = np.max(spacing) / 2.0 @@ -1307,17 +1309,17 @@ def _on_colormap_range(self, event): return lims = {key: getattr(event, key) for key in ("fmin", "fmid", "fmax", "alpha")} # Check if limits have changed at all. - if all(val is None or val == self._data[key] for key, val in lims.items()): + if all(val is None or val == self._active_data[key] for key, val in lims.items()): return # Update the GUI elements. with disable_ui_events(self): for key, val in lims.items(): if val is not None: if key in self.widgets: - self.widgets[key].set_value(val * self._data["fscale"]) + self.widgets[key].set_value(val * self._active_data["fscale"]) entry_key = "entry_" + key if entry_key in self.widgets: - self.widgets[entry_key].set_value(val * self._data["fscale"]) + self.widgets[entry_key].set_value(val * self._active_data["fscale"]) # Update the render. self._update_colormap_range(**lims) @@ -1325,7 +1327,7 @@ def _on_vertex_select(self, event): """Respond to vertex_select UI event.""" if event.hemi == "vol": try: - mesh = self._data[event.hemi]["grid"] + mesh = self._active_data[event.hemi]["grid"] except KeyError: return else: @@ -1479,15 +1481,15 @@ def plot_time_course(self, hemi, vertex_id, color, update=True): """ if self.mpl_canvas is None: return - time = self._data["time"].copy() # avoid circular ref + time = self._active_data["time"].copy() # avoid circular ref mni = None if hemi == "vol": hemi_str = "V" xfm = read_talxfm(self._subject, self._subjects_dir) if self._units == "mm": xfm["trans"][:3, 3] *= 1000.0 - ijk = np.unravel_index(vertex_id, self._data[hemi]["grid_shape"], order="F") - src_mri_t = self._data[hemi]["grid_src_mri_t"] + ijk = np.unravel_index(vertex_id, self._active_data[hemi]["grid_shape"], order="F") + src_mri_t = self._active_data[hemi]["grid_src_mri_t"] mni = apply_trans(xfm["trans"] @ src_mri_t, ijk) else: hemi_str = "L" if hemi == "lh" else "R" @@ -1746,10 +1748,14 @@ def add_data( If None, it is assumed to belong to the hemisphere being shown. If two hemispheres are being shown, an error will be thrown. - remove_existing : bool - Not supported yet. - Remove surface added by previous "add_data" call. Useful for - conserving memory when displaying different data in a loop. + remove_existing : bool | None + If ``True`` or ``None`` (default), the overlay added by the + previous :meth:`add_data` call is removed before the new one is + rendered. Set to ``False`` to keep existing overlays and composite + the new one on top (requires distinct ``key`` values). + + .. versionchanged:: 1.13 + Now accepts ``True`` and ``False`` in addition to ``None``. time_label_size : int Font size of the time label (default 14). initial_time : float | None @@ -1792,7 +1798,7 @@ def add_data( # those parameters are not supported yet, only None is allowed _check_option("thresh", thresh, [None]) - _check_option("remove_existing", remove_existing, [None]) + _check_option("remove_existing", remove_existing, [None, True, False]) _validate_type(time_label_size, (None, "numeric"), "time_label_size") if time_label_size is not None: time_label_size = float(time_label_size) @@ -1805,8 +1811,19 @@ def add_data( stc, array, vertices = self._check_stc(hemi, array, vertices) array = np.asarray(array) vector_alpha = alpha if vector_alpha is None else vector_alpha - self._data["vector_alpha"] = vector_alpha - self._data["scale_factor"] = scale_factor + + # Remove the previously active overlay when not keeping existing ones. + old_key = self._active_data_key + if remove_existing is not False and old_key is not None and old_key != key: + for _hemi in list(self.layered_meshes): + self.layered_meshes[_hemi].remove_overlay(old_key) + del self._data[old_key] + + # Index all per-dataset state by key so multiple overlays can coexist. + self._active_data_key = key + self._data.setdefault(key, {}) + self._data[key]["vector_alpha"] = vector_alpha + self._data[key]["scale_factor"] = scale_factor # Create time array and add label if > 1D if array.ndim <= 1: @@ -1822,7 +1839,7 @@ def add_data( f"time has shape {time.shape}, but need shape " f"{(array.shape[-1],)} (array.shape[-1])" ) - self._data["time"] = time + self._data[key]["time"] = time if self._n_times is None: self._times = time @@ -1869,30 +1886,29 @@ def add_data( f"{type(smoothing_steps)} was given." ) - self._data["stc"] = stc - self._data["src"] = src - self._data["smoothing_steps"] = smoothing_steps - self._data["clim"] = clim - self._data["time"] = time - self._data["initial_time"] = initial_time - self._data["time_label"] = time_label - self._data["initial_time_idx"] = time_idx - self._data["time_idx"] = time_idx - self._data["transparent"] = transparent + self._data[key]["stc"] = stc + self._data[key]["src"] = src + self._data[key]["smoothing_steps"] = smoothing_steps + self._data[key]["clim"] = clim + self._data[key]["time"] = time + self._data[key]["initial_time"] = initial_time + self._data[key]["time_label"] = time_label + self._data[key]["initial_time_idx"] = time_idx + self._data[key]["time_idx"] = time_idx + self._data[key]["transparent"] = transparent # data specific for a hemi - self._data[hemi] = dict() - self._data[hemi]["glyph_dataset"] = None - self._data[hemi]["glyph_mapper"] = None - self._data[hemi]["glyph_actor"] = None - self._data[hemi]["array"] = array - self._data[hemi]["vertices"] = vertices - self._data[hemi]["key"] = key - self._data["alpha"] = alpha - self._data["colormap"] = colormap - self._data["center"] = center - self._data["fmin"] = fmin - self._data["fmid"] = fmid - self._data["fmax"] = fmax + self._data[key][hemi] = dict() + self._data[key][hemi]["glyph_dataset"] = None + self._data[key][hemi]["glyph_mapper"] = None + self._data[key][hemi]["glyph_actor"] = None + self._data[key][hemi]["array"] = array + self._data[key][hemi]["vertices"] = vertices + self._data[key]["alpha"] = alpha + self._data[key]["colormap"] = colormap + self._data[key]["center"] = center + self._data[key]["fmin"] = fmin + self._data[key]["fmid"] = fmid + self._data[key]["fmax"] = fmax self._update_colormap_range() # 1) add the surfaces first @@ -1910,7 +1926,7 @@ def add_data( # set_data_smoothing calls "_update_current_time_idx" for us, which will set # _current_time self.set_time_interpolation(self.time_interpolation) - self.set_data_smoothing(self._data["smoothing_steps"]) + self.set_data_smoothing(self._data[key]["smoothing_steps"]) # 3) add the other actors if colorbar is True: @@ -1928,7 +1944,7 @@ def add_data( text=time_label(self._current_time), justification="right", ) - self._data["time_actor"] = time_actor + self._data[key]["time_actor"] = time_actor self._time_label_added = True if colorbar and self._scalar_bar is None and do: kwargs = dict( @@ -1951,7 +1967,13 @@ def add_data( def remove_data(self): """Remove rendered data from the mesh.""" - self._remove("data", render=True) + for key in list(self._data): + for hemi in list(self.layered_meshes): + self.layered_meshes[hemi].remove_overlay(key) + self._remove(key, render=False) + del self._data[key] + self._active_data_key = None + self._renderer._update() # Stop listening to events if "time_change" in _get_event_channel(self): @@ -2036,7 +2058,7 @@ def _add_volume_data(self, hemi, src, volume_options): ) alpha = volume_options["alpha"] if alpha is None: - alpha = 0.4 if self._data[hemi]["array"].ndim == 3 else 1.0 + alpha = 0.4 if self._active_data[hemi]["array"].ndim == 3 else 1.0 alpha = np.clip(float(alpha), 0.0, 1.0) resolution = volume_options["resolution"] surface_alpha = volume_options["surface_alpha"] @@ -2047,9 +2069,9 @@ def _add_volume_data(self, hemi, src, volume_options): silhouette_alpha = surface_alpha / 4.0 silhouette_linewidth = volume_options["silhouette_linewidth"] del volume_options - volume_pos = self._data[hemi].get("grid_volume_pos") - volume_neg = self._data[hemi].get("grid_volume_neg") - center = self._data["center"] + volume_pos = self._active_data[hemi].get("grid_volume_pos") + volume_neg = self._active_data[hemi].get("grid_volume_neg") + center = self._active_data["center"] if volume_pos is None: xyz = np.meshgrid(*[np.arange(s) for s in src[0]["shape"]], indexing="ij") dimensions = np.array(src[0]["shape"], int) @@ -2062,8 +2084,8 @@ def _add_volume_data(self, hemi, src, volume_options): coords = np.array([c.ravel(order="F") for c in xyz]).T coords = apply_trans(src_mri_t, coords) self.geo[hemi] = Bunch(coords=coords) - vertices = self._data[hemi]["vertices"] - assert self._data[hemi]["array"].shape[0] == len(vertices) + vertices = self._active_data[hemi]["vertices"] + assert self._active_data[hemi]["array"].shape[0] == len(vertices) # MNE constructs the source space on a uniform grid in MRI space, # but mne coreg can change it to be non-uniform, so we need to # use all three elements here @@ -2083,14 +2105,14 @@ def _add_volume_data(self, hemi, src, volume_options): center, interpolation, ) - self._data[hemi]["alpha"] = alpha # incorrectly set earlier - self._data[hemi]["grid"] = grid - self._data[hemi]["grid_mesh"] = grid_mesh - self._data[hemi]["grid_coords"] = coords - self._data[hemi]["grid_src_mri_t"] = src_mri_t - self._data[hemi]["grid_shape"] = dimensions - self._data[hemi]["grid_volume_pos"] = volume_pos - self._data[hemi]["grid_volume_neg"] = volume_neg + self._active_data[hemi]["alpha"] = alpha # incorrectly set earlier + self._active_data[hemi]["grid"] = grid + self._active_data[hemi]["grid_mesh"] = grid_mesh + self._active_data[hemi]["grid_coords"] = coords + self._active_data[hemi]["grid_src_mri_t"] = src_mri_t + self._active_data[hemi]["grid_shape"] = dimensions + self._active_data[hemi]["grid_volume_pos"] = volume_pos + self._active_data[hemi]["grid_volume_neg"] = volume_neg actor_pos, _ = self._renderer.plotter.add_actor( volume_pos, name=None, culling=False, reset_camera=False, render=False ) @@ -2099,7 +2121,7 @@ def _add_volume_data(self, hemi, src, volume_options): actor_neg, _ = self._renderer.plotter.add_actor( volume_neg, name=None, culling=False, reset_camera=False, render=False ) - grid_mesh = self._data[hemi]["grid_mesh"] + grid_mesh = self._active_data[hemi]["grid_mesh"] if grid_mesh is not None: actor_mesh, prop = self._renderer.plotter.add_actor( grid_mesh, @@ -2230,15 +2252,15 @@ def add_label( ids = ids[scalars >= scalar_thresh] if self.time_viewer and self.show_traces and self.traces_mode == "label": - stc = self._data["stc"] - src = self._data["src"] + stc = self._active_data["stc"] + src = self._active_data["src"] tc = stc.extract_label_time_course( label, src=src, mode=self.label_extract_mode ) tc = tc[0] if tc.ndim == 2 else tc[0, 0, :] color = next(self.color_cycle) line = self.mpl_canvas.plot( - self._data["time"], tc, label=label_name, color=color + self._active_data["time"], tc, label=label_name, color=color ) else: line = None @@ -2673,12 +2695,11 @@ def add_foci( self._set_camera(**views_dicts[hemi][v]) self._renderer._update() - # Store the foci in the Brain._data dictionary + # Store the foci separately from overlay data data_foci = coords - if "foci" in self._data.get(hemi, []): - data_foci = np.vstack((self._data[hemi]["foci"], data_foci)) - self._data[hemi] = self._data.get(hemi, dict()) # no data added yet - self._data[hemi]["foci"] = data_foci + if "foci" in self._foci_data.get(hemi, {}): + data_foci = np.vstack((self._foci_data[hemi]["foci"], data_foci)) + self._foci_data.setdefault(hemi, {})["foci"] = data_foci @verbose def add_sensors( @@ -3419,15 +3440,15 @@ def _update_colormap_range(self, fmin=None, fmid=None, fmax=None, alpha=None): """ args = f"{fmin}, {fmid}, {fmax}, {alpha}" logger.debug(f"Updating LUT with {args}") - center = self._data["center"] - colormap = self._data["colormap"] - transparent = self._data["transparent"] - lims = {key: self._data[key] for key in ("fmin", "fmid", "fmax")} + center = self._active_data["center"] + colormap = self._active_data["colormap"] + transparent = self._active_data["transparent"] + lims = {k: self._active_data[k] for k in ("fmin", "fmid", "fmax")} _update_monotonic(lims, fmin=fmin, fmid=fmid, fmax=fmax) assert all(val is not None for val in lims.values()) - self._data.update(lims) - self._data["ctable"] = np.round( + self._active_data.update(lims) + self._active_data["ctable"] = np.round( calculate_lut( colormap, alpha=1.0, center=center, transparent=transparent, **lims ) @@ -3435,15 +3456,15 @@ def _update_colormap_range(self, fmin=None, fmid=None, fmax=None, alpha=None): ).astype(np.uint8) # update our values rng = self._cmap_range - ctable = self._data["ctable"] + ctable = self._active_data["ctable"] for hemi in ["lh", "rh", "vol"]: - hemi_data = self._data.get(hemi) + hemi_data = self._active_data.get(hemi) if hemi_data is not None: if hemi in self.layered_meshes: mesh = self.layered_meshes[hemi] mesh.update_overlay( - name=hemi_data["key"], - colormap=self._data["ctable"], + name=self._active_data_key, + colormap=self._active_data["ctable"], opacity=alpha, rng=rng, ) @@ -3482,7 +3503,7 @@ def set_data_smoothing(self, n_steps): from ...morph import _hemi_morph for hemi in ["lh", "rh"]: - hemi_data = self._data.get(hemi) + hemi_data = self._active_data.get(hemi) if hemi_data is not None: if len(hemi_data["array"]) >= self.geo[hemi].x.shape[0]: continue @@ -3503,11 +3524,11 @@ def set_data_smoothing(self, n_steps): maps=None, warn=False, ) - self._data[hemi]["smooth_mat"] = smooth_mat + self._active_data[hemi]["smooth_mat"] = smooth_mat if hemi in self.layered_meshes: self.layered_meshes[hemi].smooth_mat = smooth_mat - self._update_current_time_idx(self._data["time_idx"]) - self._data["smoothing_steps"] = n_steps + self._update_current_time_idx(self._active_data["time_idx"]) + self._active_data["smoothing_steps"] = n_steps @property def _n_times(self): @@ -3535,17 +3556,19 @@ def set_time_interpolation(self, interpolation): self._time_interp_inv = None if self._times is not None: idx = np.arange(self._n_times) - for hemi in ["lh", "rh", "vol"]: - hemi_data = self._data.get(hemi) - if hemi_data is not None: - array = hemi_data["array"] - self._time_interp_funcs[hemi] = _safe_interp1d( - idx, - array, - self._time_interpolation, - axis=-1, - assume_sorted=True, - ) + for data_key, key_data in self._data.items(): + for hemi in ["lh", "rh", "vol"]: + hemi_data = key_data.get(hemi) + if hemi_data is not None: + array = hemi_data["array"] + if array.ndim > 1: + self._time_interp_funcs[(data_key, hemi)] = _safe_interp1d( + idx, + array, + self._time_interpolation, + axis=-1, + assume_sorted=True, + ) self._time_interp_inv = _safe_interp1d(idx, self._times) def _update_current_time_idx(self, time_idx): @@ -3558,63 +3581,76 @@ def _update_current_time_idx(self, time_idx): between indices. """ self._current_act_data = dict() - time_actor = self._data.get("time_actor", None) - time_label = self._data.get("time_label", None) + active = self._active_data + time_actor = active.get("time_actor", None) + time_label = active.get("time_label", None) for hemi in ["lh", "rh", "vol"]: - hemi_data = self._data.get(hemi) - if hemi_data is not None: + for data_key, key_data in self._data.items(): + hemi_data = key_data.get(hemi) + if hemi_data is None: + continue array = hemi_data["array"] # interpolate in time vectors = None if array.ndim == 1: act_data = array - self._current_time = 0 + if data_key == self._active_data_key: + self._current_time = 0 else: - act_data = self._time_interp_funcs[hemi](time_idx) - self._current_time = self._time_interp_inv(time_idx) + act_data = self._time_interp_funcs[(data_key, hemi)](time_idx) + if data_key == self._active_data_key: + self._current_time = self._time_interp_inv(time_idx) if array.ndim == 3: vectors = act_data act_data = np.linalg.norm(act_data, axis=1) - self._current_time = self._time_interp_inv(time_idx) - self._current_act_data[hemi] = act_data - if time_actor is not None and time_label is not None: - time_actor.SetInput(time_label(self._current_time)) - - # update the volume interpolation - grid = hemi_data.get("grid") - if grid is not None: - vertices = hemi_data["vertices"] - values = self._current_act_data[hemi] - rng = self._cmap_range - fill = 0 if self._data["center"] is not None else rng[0] - grid.point_data["values"].fill(fill) - grid.point_data["values"][vertices] = values - # This can be useful for debugging fsaverage-5 source space by - # making the value at (0, -5, 5) high - # if 21334 in vertices: - # grid.point_data["values"][21334] = values.max() + if data_key == self._active_data_key: + self._current_time = self._time_interp_inv(time_idx) + + if data_key == self._active_data_key: + self._current_act_data[hemi] = act_data + if time_actor is not None and time_label is not None: + time_actor.SetInput(time_label(self._current_time)) + + # update the volume interpolation (active key only) + if data_key == self._active_data_key: + grid = hemi_data.get("grid") + if grid is not None: + vertices = hemi_data["vertices"] + values = self._current_act_data[hemi] + rng = self._cmap_range + fill = 0 if active["center"] is not None else rng[0] + grid.point_data["values"].fill(fill) + grid.point_data["values"][vertices] = values + # This can be useful for debugging fsaverage-5 source space by + # making the value at (0, -5, 5) high + # if 21334 in vertices: + # grid.point_data["values"][21334] = values.max() # update the mesh scalar values (LayeredMesh applies smooth_mat) if hemi in self.layered_meshes: mesh = self.layered_meshes[hemi] - key = hemi_data["key"] - if key in mesh._overlays: - mesh.update_overlay(name=key, scalars=act_data) + key_rng = [ + -key_data["fmax"] if key_data["center"] is not None + else key_data["fmin"], + key_data["fmax"], + ] + if data_key in mesh._overlays: + mesh.update_overlay(name=data_key, scalars=act_data) else: mesh.add_overlay( scalars=act_data, - colormap=self._data["ctable"], - rng=self._cmap_range, + colormap=key_data["ctable"], + rng=key_rng, opacity=None, - name=key, + name=data_key, smooth=True, ) - # update the glyphs - if vectors is not None: + # update the glyphs (active key only) + if vectors is not None and data_key == self._active_data_key: self._update_glyphs(hemi, vectors) - self._data["time_idx"] = time_idx + active["time_idx"] = time_idx self._renderer._update() def set_time_point(self, time_idx): @@ -3655,11 +3691,11 @@ def set_time(self, time): ) def _update_glyphs(self, hemi, vectors): - hemi_data = self._data.get(hemi) + hemi_data = self._active_data.get(hemi) assert hemi_data is not None vertices = hemi_data["vertices"] - vector_alpha = self._data["vector_alpha"] - scale_factor = self._data["scale_factor"] + vector_alpha = self._active_data["vector_alpha"] + scale_factor = self._active_data["scale_factor"] vertices = slice(None) if vertices is None else vertices x, y, z = np.array(self.geo[hemi].coords)[vertices].T @@ -3702,16 +3738,16 @@ def _update_glyphs(self, hemi, vectors): count += 1 self._renderer._set_colormap_range( actor=glyph_actor, - ctable=self._data["ctable"], + ctable=self._active_data["ctable"], scalar_bar=None, rng=self._cmap_range, ) @property def _cmap_range(self): - dt_max = self._data["fmax"] - if self._data["center"] is None: - dt_min = self._data["fmin"] + dt_max = self._active_data["fmax"] + if self._active_data["center"] is None: + dt_min = self._active_data["fmin"] else: dt_min = -1 * dt_max rng = [dt_min, dt_max] @@ -3719,13 +3755,13 @@ def _cmap_range(self): def _update_fscale(self, fscale): """Scale the colorbar points.""" - fmin = self._data["fmin"] * fscale - fmid = self._data["fmid"] * fscale - fmax = self._data["fmax"] * fscale + fmin = self._active_data["fmin"] * fscale + fmid = self._active_data["fmid"] * fscale + fmax = self._active_data["fmax"] * fscale self.update_lut(fmin=fmin, fmid=fmid, fmax=fmax) def _update_auto_scaling(self, restore=False): - user_clim = self._data["clim"] + user_clim = self._active_data["clim"] if user_clim is not None and "lims" in user_clim: allow_pos_lims = False else: @@ -3734,8 +3770,8 @@ def _update_auto_scaling(self, restore=False): clim = user_clim else: clim = "auto" - colormap = self._data["colormap"] - transparent = self._data["transparent"] + colormap = self._active_data["colormap"] + transparent = self._active_data["transparent"] mapdata = _process_clim( clim, colormap, @@ -3750,21 +3786,27 @@ def _update_auto_scaling(self, restore=False): del mapdata fmin, fmid, fmax = scale_pts center = 0.0 if diverging else None - self._data["center"] = center - self._data["colormap"] = colormap - self._data["transparent"] = transparent + self._active_data["center"] = center + self._active_data["colormap"] = colormap + self._active_data["transparent"] = transparent self.update_lut(fmin=fmin, fmid=fmid, fmax=fmax) def _to_time_index(self, value): """Return the interpolated time index of the given time value.""" - time = self._data["time"] + time = self._active_data["time"] value = np.interp(value, time, np.arange(len(time))) return value @property def data(self): """Data used by time viewer and color bar widgets.""" - return self._data + return self._active_data + + @property + def _active_data(self): + if self._active_data_key is None: + return None + return self._data.get(self._active_data_key) @property def labels(self): @@ -4018,7 +4060,7 @@ def _iter_time(self, time_idx, callback): ----- Used by movie and image sequence saving functions. """ - current_time_idx = self._data["time_idx"] + current_time_idx = self._active_data["time_idx"] for ii, idx in enumerate(time_idx): self.set_time_point(idx) if callback is not None: From 90ca56a5910ab42064a71a5bb9b0a9ebb532f244 Mon Sep 17 00:00:00 2001 From: payam Date: Wed, 1 Jul 2026 08:38:05 +0200 Subject: [PATCH 04/12] add test for Brain.add_data with remove_existing=False --- mne/viz/_brain/tests/test_brain.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index b60f1d1fa73..0830d6fc064 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -399,6 +399,24 @@ def __init__(self): lm.update() with pytest.raises(ValueError, match="must have shape"): lm.update_overlay(name="data", scalars=np.ones(1)) + # remove_existing=False keeps the old overlay and adds a new one alongside + assert list(brain._data.keys()) == ["data"] + assert "data" in lm._overlays + brain.add_data( + hemi_data, + fmin=fmin, + hemi="lh", + fmax=fmax, + colormap="Blues", + vertices=hemi_vertices, + smoothing_steps="nearest", + colorbar=False, + key="overlay2", + remove_existing=False, + ) + assert "data" in brain._data and "overlay2" in brain._data + assert "data" in lm._overlays and "overlay2" in lm._overlays + assert brain._active_data_key == "overlay2" brain.remove_data() assert "data" not in brain._actors assert "time_change" not in ui_events._get_event_channel(brain) From 4c46a55a8a7d43a3015d34ad75dd0cb3c2cae065 Mon Sep 17 00:00:00 2001 From: payam Date: Wed, 1 Jul 2026 08:38:49 +0200 Subject: [PATCH 05/12] new example with 2 separate focal activity overlaid together --- examples/visualization/brain.py | 54 +++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/examples/visualization/brain.py b/examples/visualization/brain.py index 9ab4f6d0a9e..af24f85d155 100644 --- a/examples/visualization/brain.py +++ b/examples/visualization/brain.py @@ -20,6 +20,7 @@ # being presented auditory and visual stimuli to display the functionality # of :class:`mne.viz.Brain` for plotting data on a brain. +import numpy as np import matplotlib.pyplot as plt from matplotlib.cm import ScalarMappable from matplotlib.colors import Normalize @@ -208,3 +209,56 @@ smoothing_steps=5, ) brain.show_view(azimuth=190, elevation=70, distance=350, focalpoint=(0, 0, 20)) + +# %% +# Composite two overlays simultaneously +# -------------------------------------- +# +# Pass ``remove_existing=False`` and a distinct ``key`` to keep the first +# overlay visible while adding a second one on top. The two layers are +# alpha-composited by :class:`~mne.viz.LayeredMesh` so both datasets appear +# at the same time. +# +# Here we simulate two focal patches of activity in different brain regions: +# a temporal source (red/hot) and a frontal source (blue). + +brain = mne.viz.Brain( + "sample", + subjects_dir=subjects_dir, + hemi="lh", + alpha=0.1, + background="white", + cortex="low_contrast", +) +coords = brain.geo["lh"].coords # vertex positions in mm + + +def gaussian_patch(coords, center, sigma=15.0): + """Gaussian blob of activity centred on a surface coordinate (mm).""" + d = np.linalg.norm(coords - center, axis=1) + return np.exp(-(d**2) / (2 * sigma**2)) + +temporal = gaussian_patch(coords, center=np.array([-52.0, -18.0, -8.0])) +brain.add_data( + temporal, + hemi="lh", + fmin=0.1, + fmax=1.5, + colormap="hot", + key="temporal", + smoothing_steps=5, +) + +frontal = gaussian_patch(coords, center=np.array([-38.0, 28.0, 46.0])) +brain.add_data( + frontal, + hemi="lh", + fmin=0.1, + fmax=0.6, + colormap="Blues", + alpha=0.5, + key="frontal", + remove_existing=False, + smoothing_steps=5, +) +brain.show_view(azimuth=180, elevation=70, distance=380, focalpoint=(0, 10, 20)) From 90777514b2459b33df9fbd74fd4c99de6ceb0e25 Mon Sep 17 00:00:00 2001 From: payam Date: Wed, 1 Jul 2026 08:40:18 +0200 Subject: [PATCH 06/12] DOC: Add changelog entry for multi-overlay Brain.add_data --- doc/changes/dev/13995.newfeature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 doc/changes/dev/13995.newfeature.rst diff --git a/doc/changes/dev/13995.newfeature.rst b/doc/changes/dev/13995.newfeature.rst new file mode 100644 index 00000000000..f6cb975edb5 --- /dev/null +++ b/doc/changes/dev/13995.newfeature.rst @@ -0,0 +1 @@ +:meth:`~mne.viz.Brain.add_data` now supports multiple simultaneous overlays via ``remove_existing=False`` and distinct ``key`` values. The ``remove_existing`` parameter now accepts ``True`` / ``False`` in addition to ``None``. By `Payam Sadeghi-Shabestari`_. From a6a907921437e58782a4af42aedbfdb470e8d208 Mon Sep 17 00:00:00 2001 From: payam Date: Wed, 1 Jul 2026 09:12:58 +0200 Subject: [PATCH 07/12] FIX: remove_data cleans all overlay keys and keep _active_data when None and fix _get_range to use _active_data --- mne/viz/_brain/_brain.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 67c86648990..11fabc7af67 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -378,7 +378,6 @@ def __init__( # for now only one time label can be added # since it is the same for all figures self._time_label_added = False - # array of data used by TimeViewer; keyed by overlay name self._data = {} self._active_data_key = None self._foci_data = {} @@ -1812,14 +1811,13 @@ def add_data( array = np.asarray(array) vector_alpha = alpha if vector_alpha is None else vector_alpha - # Remove the previously active overlay when not keeping existing ones. + # remove the previously active overlay when not keeping existing ones. old_key = self._active_data_key if remove_existing is not False and old_key is not None and old_key != key: for _hemi in list(self.layered_meshes): self.layered_meshes[_hemi].remove_overlay(old_key) del self._data[old_key] - # Index all per-dataset state by key so multiple overlays can coexist. self._active_data_key = key self._data.setdefault(key, {}) self._data[key]["vector_alpha"] = vector_alpha @@ -1967,13 +1965,7 @@ def add_data( def remove_data(self): """Remove rendered data from the mesh.""" - for key in list(self._data): - for hemi in list(self.layered_meshes): - self.layered_meshes[hemi].remove_overlay(key) - self._remove(key, render=False) - del self._data[key] - self._active_data_key = None - self._renderer._update() + self._remove("data", render=True) # Stop listening to events if "time_change" in _get_event_channel(self): @@ -3804,9 +3796,7 @@ def data(self): @property def _active_data(self): - if self._active_data_key is None: - return None - return self._data.get(self._active_data_key) + return self._data[self._active_data_key] @property def labels(self): @@ -4243,7 +4233,7 @@ def _get_range(brain): multiplied by the scaling factor and when getting a value, this value should be divided by the scaling factor. """ - fmax = abs(brain._data["fmax"]) + fmax = abs(brain._active_data["fmax"]) if 1e-02 <= fmax <= 1e02: fscale_power = 0 else: From d347950390d1477c10e945f0818bd19bf4f18ea1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Jul 2026 07:15:20 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/visualization/brain.py | 3 ++- mne/viz/_brain/_brain.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/visualization/brain.py b/examples/visualization/brain.py index af24f85d155..dc2a89779d9 100644 --- a/examples/visualization/brain.py +++ b/examples/visualization/brain.py @@ -20,8 +20,8 @@ # being presented auditory and visual stimuli to display the functionality # of :class:`mne.viz.Brain` for plotting data on a brain. -import numpy as np import matplotlib.pyplot as plt +import numpy as np from matplotlib.cm import ScalarMappable from matplotlib.colors import Normalize @@ -238,6 +238,7 @@ def gaussian_patch(coords, center, sigma=15.0): d = np.linalg.norm(coords - center, axis=1) return np.exp(-(d**2) / (2 * sigma**2)) + temporal = gaussian_patch(coords, center=np.array([-52.0, -18.0, -8.0])) brain.add_data( temporal, diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 11fabc7af67..da204aba7be 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -1058,7 +1058,9 @@ def _configure_vertex_time_course(self): # initialize the default point if self._active_data["initial_time"] is not None: # pick at that time - use_data = act_data[:, [np.round(self._active_data["time_idx"]).astype(int)]] + use_data = act_data[ + :, [np.round(self._active_data["time_idx"]).astype(int)] + ] else: use_data = act_data ind = np.unravel_index( @@ -1308,7 +1310,9 @@ def _on_colormap_range(self, event): return lims = {key: getattr(event, key) for key in ("fmin", "fmid", "fmax", "alpha")} # Check if limits have changed at all. - if all(val is None or val == self._active_data[key] for key, val in lims.items()): + if all( + val is None or val == self._active_data[key] for key, val in lims.items() + ): return # Update the GUI elements. with disable_ui_events(self): @@ -1318,7 +1322,9 @@ def _on_colormap_range(self, event): self.widgets[key].set_value(val * self._active_data["fscale"]) entry_key = "entry_" + key if entry_key in self.widgets: - self.widgets[entry_key].set_value(val * self._active_data["fscale"]) + self.widgets[entry_key].set_value( + val * self._active_data["fscale"] + ) # Update the render. self._update_colormap_range(**lims) @@ -1487,7 +1493,9 @@ def plot_time_course(self, hemi, vertex_id, color, update=True): xfm = read_talxfm(self._subject, self._subjects_dir) if self._units == "mm": xfm["trans"][:3, 3] *= 1000.0 - ijk = np.unravel_index(vertex_id, self._active_data[hemi]["grid_shape"], order="F") + ijk = np.unravel_index( + vertex_id, self._active_data[hemi]["grid_shape"], order="F" + ) src_mri_t = self._active_data[hemi]["grid_src_mri_t"] mni = apply_trans(xfm["trans"] @ src_mri_t, ijk) else: @@ -3622,7 +3630,8 @@ def _update_current_time_idx(self, time_idx): if hemi in self.layered_meshes: mesh = self.layered_meshes[hemi] key_rng = [ - -key_data["fmax"] if key_data["center"] is not None + -key_data["fmax"] + if key_data["center"] is not None else key_data["fmin"], key_data["fmax"], ] From 1648cf8439cb13fc2b84e83f4daa0d9a292e1b43 Mon Sep 17 00:00:00 2001 From: payam Date: Wed, 1 Jul 2026 09:38:14 +0200 Subject: [PATCH 09/12] FIX: use _active_data and _foci_data after _data restructure instead of _data --- mne/viz/_brain/tests/test_brain.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 0830d6fc064..74a954a3405 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -93,7 +93,7 @@ def GetPickPosition(self): """Return the picked position.""" if self.hemi == "vol": self.point_id = self.cell_id - return self.brain._data["vol"]["grid_coords"][self.cell_id] + return self.brain._active_data["vol"]["grid_coords"][self.cell_id] else: vtk_cell = self.mesh.GetCell(self.cell_id) cell = [ @@ -994,7 +994,7 @@ def test_brain_time_viewer(renderer_interactive_pyvistaqt, pixel_ratio, brain_gc with use_log_level("debug"): brain.update_lut(fmin=12.0) - assert brain._data["fmin"] == 12.0 + assert brain._active_data["fmin"] == 12.0 brain.update_lut(fmax=4.0) _assert_brain_range(brain, [4.0, 4.0]) brain.update_lut(fmid=6.0) @@ -1162,7 +1162,7 @@ def test_brain_traces_vertex( # add foci should work for 'lh', 'rh' and 'vol' for current_hemi in hemi_str: brain.add_foci([[0, 0, 0]], hemi=current_hemi) - assert_array_equal(brain._data[current_hemi]["foci"], [[0, 0, 0]]) + assert_array_equal(brain._foci_data[current_hemi]["foci"], [[0, 0, 0]]) # test points picked by default picked_points = brain.get_picked_points() @@ -1195,8 +1195,8 @@ def test_brain_traces_vertex( for idx, current_hemi in enumerate(hemi_str): assert len(spheres) == 0 if current_hemi == "vol": - current_mesh = brain._data["vol"]["grid"] - vertices = brain._data["vol"]["vertices"] + current_mesh = brain._active_data["vol"]["grid"] + vertices = brain._active_data["vol"]["vertices"] values = current_mesh.point_data["values"][vertices] cell_id = vertices[np.argmax(np.abs(values))] else: @@ -1302,7 +1302,7 @@ def test_brain_traces_colormap(renderer_interactive_pyvistaqt, brain_gc): add_data_kwargs=dict(colorbar_kwargs=dict(n_labels=3)), ) # mne_analyze should be chosen - ctab = brain._data["ctable"] + ctab = brain._active_data["ctable"] assert_array_equal(ctab[0], [0, 255, 255, 255]) # opaque cyan assert_array_equal(ctab[-1], [255, 255, 0, 255]) # opaque yellow assert_allclose(ctab[len(ctab) // 2], [128, 128, 128, 0], atol=3) @@ -1509,7 +1509,7 @@ def test_brain_ui_events(renderer_interactive_pyvistaqt, brain_gc): kind="distributed_source_power", fmin=1, fmid=2, fmax=3, alpha=True ), ) - assert_array_equal(brain._data["ctable"][:3, 3], [0, 2, 4]) + assert_array_equal(brain._active_data["ctable"][:3, 3], [0, 2, 4]) # This event should be ignored. ui_events.publish( @@ -1519,7 +1519,7 @@ def test_brain_ui_events(renderer_interactive_pyvistaqt, brain_gc): ), ) # Should remain unchanged. - assert_array_equal(brain._data["ctable"][:3, 3], [0, 2, 4]) + assert_array_equal(brain._active_data["ctable"][:3, 3], [0, 2, 4]) brain.close() @@ -1604,4 +1604,4 @@ def test_foci_mapping(tmp_path, renderer_interactive_pyvistaqt): tiny_brain, _ = tiny(tmp_path) foci_coords = tiny_brain.geo["lh"].coords[:2] + 0.01 tiny_brain.add_foci(foci_coords, map_surface="white") - assert_array_equal(tiny_brain._data["lh"]["foci"], tiny_brain.geo["lh"].coords[:2]) + assert_array_equal(tiny_brain._foci_data["lh"]["foci"], tiny_brain.geo["lh"].coords[:2]) From 079eb5583f149eb0fc1cbfa5ce683fdae1610c4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Jul 2026 07:42:01 +0000 Subject: [PATCH 10/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/viz/_brain/tests/test_brain.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 74a954a3405..9f30af36955 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -1604,4 +1604,6 @@ def test_foci_mapping(tmp_path, renderer_interactive_pyvistaqt): tiny_brain, _ = tiny(tmp_path) foci_coords = tiny_brain.geo["lh"].coords[:2] + 0.01 tiny_brain.add_foci(foci_coords, map_surface="white") - assert_array_equal(tiny_brain._foci_data["lh"]["foci"], tiny_brain.geo["lh"].coords[:2]) + assert_array_equal( + tiny_brain._foci_data["lh"]["foci"], tiny_brain.geo["lh"].coords[:2] + ) From a6a854c4fa82006d3d657382427a1fb861c36816 Mon Sep 17 00:00:00 2001 From: payam Date: Fri, 3 Jul 2026 07:34:36 +0200 Subject: [PATCH 11/12] use _all_data as overlay store and now _data property returns active overlay --- mne/viz/_brain/_brain.py | 252 ++++++++++++++--------------- mne/viz/_brain/tests/test_brain.py | 18 +-- 2 files changed, 135 insertions(+), 135 deletions(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index da204aba7be..e92e3602a63 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -378,7 +378,7 @@ def __init__( # for now only one time label can be added # since it is the same for all figures self._time_label_added = False - self._data = {} + self._all_data = {} self._active_data_key = None self._foci_data = {} self.geo = {} @@ -515,7 +515,7 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): """ if self.time_viewer: return - if not self._data: + if not self._all_data: raise ValueError("No data to visualize. See ``add_data``.") self.time_viewer = time_viewer self.orientation = list(_lh_views_dict.keys()) @@ -643,7 +643,7 @@ def _clean(self): "actions", "widgets", "geo", - "_data", + "_all_data", ): setattr(self, key, None) self._cleaned = True @@ -714,7 +714,7 @@ def set_playback_speed(self, speed): publish(self, PlaybackSpeed(speed=speed)) def _configure_time_label(self): - self.time_actor = self._active_data.get("time_actor") + self.time_actor = self._data.get("time_actor") if self.time_actor is not None: self.time_actor.SetPosition(0.5, 0.03) self.time_actor.GetTextProperty().SetJustificationToCentered() @@ -728,7 +728,7 @@ def _configure_scalar_bar(self): self._scalar_bar.SetPosition(0.02, 0.2) def _configure_dock_playback_widget(self, name): - len_time = len(self._active_data["time"]) - 1 + len_time = len(self._data["time"]) - 1 # Time widget if len_time < 1: @@ -745,7 +745,7 @@ def current_time_func(): self._renderer._enable_time_interaction( self, current_time_func, - self._active_data["time"], + self._data["time"], self.default_playback_speed_value, self.default_playback_speed_range, ) @@ -753,7 +753,7 @@ def current_time_func(): # Time label current_time = self._current_time assert current_time is not None # should never be the case, float - time_label = self._active_data["time_label"] + time_label = self._data["time_label"] if callable(time_label): current_time = time_label(current_time) else: @@ -822,7 +822,7 @@ def set_orientation(value, orientation_data=orientation_data): def _configure_dock_colormap_widget(self, name): fmax, fscale, fscale_power = _get_range(self) rng = [0, fmax * fscale] - self._active_data["fscale"] = fscale + self._data["fscale"] = fscale layout = self._renderer._dock_add_group_box(name) text = "min / mid / max" @@ -837,14 +837,14 @@ def _configure_dock_colormap_widget(self, name): @_auto_weakref def update_single_lut_value(value, key): # Called by the sliders and spin boxes. - self.update_lut(**{key: value / self._active_data["fscale"]}) + self.update_lut(**{key: value / self._data["fscale"]}) keys = ("fmin", "fmid", "fmax") for key in keys: hlayout = self._renderer._dock_add_layout(vertical=False) self.widgets[key] = self._renderer._dock_add_slider( name=None, - value=self._active_data[key] * self._active_data["fscale"], + value=self._all_data[key] * self._data["fscale"], rng=rng, callback=partial(update_single_lut_value, key=key), double=True, @@ -852,7 +852,7 @@ def update_single_lut_value(value, key): ) self.widgets[f"entry_{key}"] = self._renderer._dock_add_spin_box( name=None, - value=self._active_data[key] * self._active_data["fscale"], + value=self._all_data[key] * self._data["fscale"], callback=partial(update_single_lut_value, key=key), rng=rng, layout=hlayout, @@ -901,8 +901,8 @@ def _configure_dock_trace_widget(self, name): return # do not show trace mode for volumes if ( - self._active_data.get("src", None) is not None - and self._active_data["src"].kind == "volume" + self._data.get("src", None) is not None + and self._data["src"].kind == "volume" ): self._configure_vertex_time_course() return @@ -950,9 +950,9 @@ def _set_label_mode(mode): cands = _read_annot_cands(dir_name, raise_error=False) cands = cands + ["None"] self.annot = cands[0] - stc = self._active_data["stc"] + stc = self._data["stc"] modes = _get_allowed_label_modes(stc) - if self._active_data["src"] is None: + if self._data["src"] is None: modes = [ m for m in modes if m not in self.default_label_extract_modes["src"] ] @@ -986,7 +986,7 @@ def _configure_dock(self): # Smoothing widget self.widgets["smoothing"] = self._renderer._dock_add_spin_box( name="Smoothing", - value=self._active_data["smoothing_steps"], + value=self._data["smoothing_steps"], rng=self.default_smoothing_range, callback=self.set_data_smoothing, double=False, @@ -1002,7 +1002,7 @@ def _configure_mplcanvas(self): show_traces=self.show_traces, separate_canvas=self.separate_canvas, ) - xlim = [np.min(self._active_data["time"]), np.max(self._active_data["time"])] + xlim = [np.min(self._data["time"]), np.max(self._data["time"])] with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) self.mpl_canvas.axes.set(xlim=xlim) @@ -1029,7 +1029,7 @@ def _configure_vertex_time_course(self): del y (self.rms,) = self.mpl_canvas.axes.plot( - self._active_data["time"], + self._data["time"], rms, lw=3, label="RMS", @@ -1047,7 +1047,7 @@ def _configure_vertex_time_course(self): act_data = self.act_data_smooth.get(hemi, [None])[0] if act_data is None: continue - hemi_data = self._active_data[hemi] + hemi_data = self._data[hemi] vertices = hemi_data["vertices"] # simulate a picked renderer @@ -1056,11 +1056,9 @@ def _configure_vertex_time_course(self): self._picked_renderer = self._renderer._all_renderers[idx] # initialize the default point - if self._active_data["initial_time"] is not None: + if self._data["initial_time"] is not None: # pick at that time - use_data = act_data[ - :, [np.round(self._active_data["time_idx"]).astype(int)] - ] + use_data = act_data[:, [np.round(self._data["time_idx"]).astype(int)]] else: use_data = act_data ind = np.unravel_index( @@ -1072,7 +1070,7 @@ def _configure_vertex_time_course(self): def _configure_picking(self): # get data for each hemi for idx, hemi in enumerate(["vol", "lh", "rh"]): - hemi_data = self._active_data.get(hemi) + hemi_data = self._data.get(hemi) if hemi_data is not None: act_data = hemi_data["array"] if act_data.ndim == 3: @@ -1250,9 +1248,9 @@ def _on_pick(self, vtk_picker, event): # camera-to-click array, which fortunately we can get "just" # by inspecting the points that are sufficiently close to the # ray. - grid = self._active_data[hemi]["grid"] - vertices = self._active_data[hemi]["vertices"] - coords = self._active_data[hemi]["grid_coords"][vertices] + grid = self._data[hemi]["grid"] + vertices = self._data[hemi]["vertices"] + coords = self._data[hemi]["grid_coords"][vertices] scalars = grid.point_data["values"][vertices] spacing = np.array(grid.GetSpacing()) max_dist = np.max(spacing) / 2.0 @@ -1310,21 +1308,17 @@ def _on_colormap_range(self, event): return lims = {key: getattr(event, key) for key in ("fmin", "fmid", "fmax", "alpha")} # Check if limits have changed at all. - if all( - val is None or val == self._active_data[key] for key, val in lims.items() - ): + if all(val is None or val == self._all_data[key] for key, val in lims.items()): return # Update the GUI elements. with disable_ui_events(self): for key, val in lims.items(): if val is not None: if key in self.widgets: - self.widgets[key].set_value(val * self._active_data["fscale"]) + self.widgets[key].set_value(val * self._data["fscale"]) entry_key = "entry_" + key if entry_key in self.widgets: - self.widgets[entry_key].set_value( - val * self._active_data["fscale"] - ) + self.widgets[entry_key].set_value(val * self._data["fscale"]) # Update the render. self._update_colormap_range(**lims) @@ -1332,7 +1326,7 @@ def _on_vertex_select(self, event): """Respond to vertex_select UI event.""" if event.hemi == "vol": try: - mesh = self._active_data[event.hemi]["grid"] + mesh = self._data[event.hemi]["grid"] except KeyError: return else: @@ -1486,17 +1480,15 @@ def plot_time_course(self, hemi, vertex_id, color, update=True): """ if self.mpl_canvas is None: return - time = self._active_data["time"].copy() # avoid circular ref + time = self._data["time"].copy() # avoid circular ref mni = None if hemi == "vol": hemi_str = "V" xfm = read_talxfm(self._subject, self._subjects_dir) if self._units == "mm": xfm["trans"][:3, 3] *= 1000.0 - ijk = np.unravel_index( - vertex_id, self._active_data[hemi]["grid_shape"], order="F" - ) - src_mri_t = self._active_data[hemi]["grid_src_mri_t"] + ijk = np.unravel_index(vertex_id, self._data[hemi]["grid_shape"], order="F") + src_mri_t = self._data[hemi]["grid_src_mri_t"] mni = apply_trans(xfm["trans"] @ src_mri_t, ijk) else: hemi_str = "L" if hemi == "lh" else "R" @@ -1824,12 +1816,12 @@ def add_data( if remove_existing is not False and old_key is not None and old_key != key: for _hemi in list(self.layered_meshes): self.layered_meshes[_hemi].remove_overlay(old_key) - del self._data[old_key] + del self._all_data[old_key] self._active_data_key = key - self._data.setdefault(key, {}) - self._data[key]["vector_alpha"] = vector_alpha - self._data[key]["scale_factor"] = scale_factor + self._all_data.setdefault(key, {}) + self._all_data[key]["vector_alpha"] = vector_alpha + self._all_data[key]["scale_factor"] = scale_factor # Create time array and add label if > 1D if array.ndim <= 1: @@ -1845,7 +1837,7 @@ def add_data( f"time has shape {time.shape}, but need shape " f"{(array.shape[-1],)} (array.shape[-1])" ) - self._data[key]["time"] = time + self._all_data[key]["time"] = time if self._n_times is None: self._times = time @@ -1892,29 +1884,29 @@ def add_data( f"{type(smoothing_steps)} was given." ) - self._data[key]["stc"] = stc - self._data[key]["src"] = src - self._data[key]["smoothing_steps"] = smoothing_steps - self._data[key]["clim"] = clim - self._data[key]["time"] = time - self._data[key]["initial_time"] = initial_time - self._data[key]["time_label"] = time_label - self._data[key]["initial_time_idx"] = time_idx - self._data[key]["time_idx"] = time_idx - self._data[key]["transparent"] = transparent + self._all_data[key]["stc"] = stc + self._all_data[key]["src"] = src + self._all_data[key]["smoothing_steps"] = smoothing_steps + self._all_data[key]["clim"] = clim + self._all_data[key]["time"] = time + self._all_data[key]["initial_time"] = initial_time + self._all_data[key]["time_label"] = time_label + self._all_data[key]["initial_time_idx"] = time_idx + self._all_data[key]["time_idx"] = time_idx + self._all_data[key]["transparent"] = transparent # data specific for a hemi - self._data[key][hemi] = dict() - self._data[key][hemi]["glyph_dataset"] = None - self._data[key][hemi]["glyph_mapper"] = None - self._data[key][hemi]["glyph_actor"] = None - self._data[key][hemi]["array"] = array - self._data[key][hemi]["vertices"] = vertices - self._data[key]["alpha"] = alpha - self._data[key]["colormap"] = colormap - self._data[key]["center"] = center - self._data[key]["fmin"] = fmin - self._data[key]["fmid"] = fmid - self._data[key]["fmax"] = fmax + self._all_data[key][hemi] = dict() + self._all_data[key][hemi]["glyph_dataset"] = None + self._all_data[key][hemi]["glyph_mapper"] = None + self._all_data[key][hemi]["glyph_actor"] = None + self._all_data[key][hemi]["array"] = array + self._all_data[key][hemi]["vertices"] = vertices + self._all_data[key]["alpha"] = alpha + self._all_data[key]["colormap"] = colormap + self._all_data[key]["center"] = center + self._all_data[key]["fmin"] = fmin + self._all_data[key]["fmid"] = fmid + self._all_data[key]["fmax"] = fmax self._update_colormap_range() # 1) add the surfaces first @@ -1932,7 +1924,7 @@ def add_data( # set_data_smoothing calls "_update_current_time_idx" for us, which will set # _current_time self.set_time_interpolation(self.time_interpolation) - self.set_data_smoothing(self._data[key]["smoothing_steps"]) + self.set_data_smoothing(self._all_data[key]["smoothing_steps"]) # 3) add the other actors if colorbar is True: @@ -1950,7 +1942,7 @@ def add_data( text=time_label(self._current_time), justification="right", ) - self._data[key]["time_actor"] = time_actor + self._all_data[key]["time_actor"] = time_actor self._time_label_added = True if colorbar and self._scalar_bar is None and do: kwargs = dict( @@ -1973,7 +1965,13 @@ def add_data( def remove_data(self): """Remove rendered data from the mesh.""" - self._remove("data", render=True) + for key in list(self._all_data): + for hemi in list(self.layered_meshes): + self.layered_meshes[hemi].remove_overlay(key) + self._remove(key, render=False) + del self._all_data[key] + self._active_data_key = None + self._renderer._update() # Stop listening to events if "time_change" in _get_event_channel(self): @@ -2058,7 +2056,7 @@ def _add_volume_data(self, hemi, src, volume_options): ) alpha = volume_options["alpha"] if alpha is None: - alpha = 0.4 if self._active_data[hemi]["array"].ndim == 3 else 1.0 + alpha = 0.4 if self._data[hemi]["array"].ndim == 3 else 1.0 alpha = np.clip(float(alpha), 0.0, 1.0) resolution = volume_options["resolution"] surface_alpha = volume_options["surface_alpha"] @@ -2069,9 +2067,9 @@ def _add_volume_data(self, hemi, src, volume_options): silhouette_alpha = surface_alpha / 4.0 silhouette_linewidth = volume_options["silhouette_linewidth"] del volume_options - volume_pos = self._active_data[hemi].get("grid_volume_pos") - volume_neg = self._active_data[hemi].get("grid_volume_neg") - center = self._active_data["center"] + volume_pos = self._data[hemi].get("grid_volume_pos") + volume_neg = self._data[hemi].get("grid_volume_neg") + center = self._data["center"] if volume_pos is None: xyz = np.meshgrid(*[np.arange(s) for s in src[0]["shape"]], indexing="ij") dimensions = np.array(src[0]["shape"], int) @@ -2084,8 +2082,8 @@ def _add_volume_data(self, hemi, src, volume_options): coords = np.array([c.ravel(order="F") for c in xyz]).T coords = apply_trans(src_mri_t, coords) self.geo[hemi] = Bunch(coords=coords) - vertices = self._active_data[hemi]["vertices"] - assert self._active_data[hemi]["array"].shape[0] == len(vertices) + vertices = self._data[hemi]["vertices"] + assert self._data[hemi]["array"].shape[0] == len(vertices) # MNE constructs the source space on a uniform grid in MRI space, # but mne coreg can change it to be non-uniform, so we need to # use all three elements here @@ -2105,14 +2103,14 @@ def _add_volume_data(self, hemi, src, volume_options): center, interpolation, ) - self._active_data[hemi]["alpha"] = alpha # incorrectly set earlier - self._active_data[hemi]["grid"] = grid - self._active_data[hemi]["grid_mesh"] = grid_mesh - self._active_data[hemi]["grid_coords"] = coords - self._active_data[hemi]["grid_src_mri_t"] = src_mri_t - self._active_data[hemi]["grid_shape"] = dimensions - self._active_data[hemi]["grid_volume_pos"] = volume_pos - self._active_data[hemi]["grid_volume_neg"] = volume_neg + self._data[hemi]["alpha"] = alpha # incorrectly set earlier + self._data[hemi]["grid"] = grid + self._data[hemi]["grid_mesh"] = grid_mesh + self._data[hemi]["grid_coords"] = coords + self._data[hemi]["grid_src_mri_t"] = src_mri_t + self._data[hemi]["grid_shape"] = dimensions + self._data[hemi]["grid_volume_pos"] = volume_pos + self._data[hemi]["grid_volume_neg"] = volume_neg actor_pos, _ = self._renderer.plotter.add_actor( volume_pos, name=None, culling=False, reset_camera=False, render=False ) @@ -2121,7 +2119,7 @@ def _add_volume_data(self, hemi, src, volume_options): actor_neg, _ = self._renderer.plotter.add_actor( volume_neg, name=None, culling=False, reset_camera=False, render=False ) - grid_mesh = self._active_data[hemi]["grid_mesh"] + grid_mesh = self._data[hemi]["grid_mesh"] if grid_mesh is not None: actor_mesh, prop = self._renderer.plotter.add_actor( grid_mesh, @@ -2252,15 +2250,15 @@ def add_label( ids = ids[scalars >= scalar_thresh] if self.time_viewer and self.show_traces and self.traces_mode == "label": - stc = self._active_data["stc"] - src = self._active_data["src"] + stc = self._data["stc"] + src = self._data["src"] tc = stc.extract_label_time_course( label, src=src, mode=self.label_extract_mode ) tc = tc[0] if tc.ndim == 2 else tc[0, 0, :] color = next(self.color_cycle) line = self.mpl_canvas.plot( - self._active_data["time"], tc, label=label_name, color=color + self._data["time"], tc, label=label_name, color=color ) else: line = None @@ -3440,15 +3438,15 @@ def _update_colormap_range(self, fmin=None, fmid=None, fmax=None, alpha=None): """ args = f"{fmin}, {fmid}, {fmax}, {alpha}" logger.debug(f"Updating LUT with {args}") - center = self._active_data["center"] - colormap = self._active_data["colormap"] - transparent = self._active_data["transparent"] - lims = {k: self._active_data[k] for k in ("fmin", "fmid", "fmax")} + center = self._data["center"] + colormap = self._data["colormap"] + transparent = self._data["transparent"] + lims = {k: self._data[k] for k in ("fmin", "fmid", "fmax")} _update_monotonic(lims, fmin=fmin, fmid=fmid, fmax=fmax) assert all(val is not None for val in lims.values()) - self._active_data.update(lims) - self._active_data["ctable"] = np.round( + self._data.update(lims) + self._data["ctable"] = np.round( calculate_lut( colormap, alpha=1.0, center=center, transparent=transparent, **lims ) @@ -3456,15 +3454,15 @@ def _update_colormap_range(self, fmin=None, fmid=None, fmax=None, alpha=None): ).astype(np.uint8) # update our values rng = self._cmap_range - ctable = self._active_data["ctable"] + ctable = self._data["ctable"] for hemi in ["lh", "rh", "vol"]: - hemi_data = self._active_data.get(hemi) + hemi_data = self._data.get(hemi) if hemi_data is not None: if hemi in self.layered_meshes: mesh = self.layered_meshes[hemi] mesh.update_overlay( name=self._active_data_key, - colormap=self._active_data["ctable"], + colormap=self._data["ctable"], opacity=alpha, rng=rng, ) @@ -3503,7 +3501,7 @@ def set_data_smoothing(self, n_steps): from ...morph import _hemi_morph for hemi in ["lh", "rh"]: - hemi_data = self._active_data.get(hemi) + hemi_data = self._data.get(hemi) if hemi_data is not None: if len(hemi_data["array"]) >= self.geo[hemi].x.shape[0]: continue @@ -3524,11 +3522,11 @@ def set_data_smoothing(self, n_steps): maps=None, warn=False, ) - self._active_data[hemi]["smooth_mat"] = smooth_mat + self._data[hemi]["smooth_mat"] = smooth_mat if hemi in self.layered_meshes: self.layered_meshes[hemi].smooth_mat = smooth_mat - self._update_current_time_idx(self._active_data["time_idx"]) - self._active_data["smoothing_steps"] = n_steps + self._update_current_time_idx(self._data["time_idx"]) + self._data["smoothing_steps"] = n_steps @property def _n_times(self): @@ -3556,7 +3554,7 @@ def set_time_interpolation(self, interpolation): self._time_interp_inv = None if self._times is not None: idx = np.arange(self._n_times) - for data_key, key_data in self._data.items(): + for data_key, key_data in self._all_data.items(): for hemi in ["lh", "rh", "vol"]: hemi_data = key_data.get(hemi) if hemi_data is not None: @@ -3581,11 +3579,11 @@ def _update_current_time_idx(self, time_idx): between indices. """ self._current_act_data = dict() - active = self._active_data + active = self._data time_actor = active.get("time_actor", None) time_label = active.get("time_label", None) for hemi in ["lh", "rh", "vol"]: - for data_key, key_data in self._data.items(): + for data_key, key_data in self._all_data.items(): hemi_data = key_data.get(hemi) if hemi_data is None: continue @@ -3692,11 +3690,11 @@ def set_time(self, time): ) def _update_glyphs(self, hemi, vectors): - hemi_data = self._active_data.get(hemi) + hemi_data = self._data.get(hemi) assert hemi_data is not None vertices = hemi_data["vertices"] - vector_alpha = self._active_data["vector_alpha"] - scale_factor = self._active_data["scale_factor"] + vector_alpha = self._data["vector_alpha"] + scale_factor = self._data["scale_factor"] vertices = slice(None) if vertices is None else vertices x, y, z = np.array(self.geo[hemi].coords)[vertices].T @@ -3739,16 +3737,16 @@ def _update_glyphs(self, hemi, vectors): count += 1 self._renderer._set_colormap_range( actor=glyph_actor, - ctable=self._active_data["ctable"], + ctable=self._data["ctable"], scalar_bar=None, rng=self._cmap_range, ) @property def _cmap_range(self): - dt_max = self._active_data["fmax"] - if self._active_data["center"] is None: - dt_min = self._active_data["fmin"] + dt_max = self._data["fmax"] + if self._data["center"] is None: + dt_min = self._data["fmin"] else: dt_min = -1 * dt_max rng = [dt_min, dt_max] @@ -3756,13 +3754,13 @@ def _cmap_range(self): def _update_fscale(self, fscale): """Scale the colorbar points.""" - fmin = self._active_data["fmin"] * fscale - fmid = self._active_data["fmid"] * fscale - fmax = self._active_data["fmax"] * fscale + fmin = self._data["fmin"] * fscale + fmid = self._data["fmid"] * fscale + fmax = self._data["fmax"] * fscale self.update_lut(fmin=fmin, fmid=fmid, fmax=fmax) def _update_auto_scaling(self, restore=False): - user_clim = self._active_data["clim"] + user_clim = self._data["clim"] if user_clim is not None and "lims" in user_clim: allow_pos_lims = False else: @@ -3771,8 +3769,8 @@ def _update_auto_scaling(self, restore=False): clim = user_clim else: clim = "auto" - colormap = self._active_data["colormap"] - transparent = self._active_data["transparent"] + colormap = self._data["colormap"] + transparent = self._data["transparent"] mapdata = _process_clim( clim, colormap, @@ -3787,25 +3785,27 @@ def _update_auto_scaling(self, restore=False): del mapdata fmin, fmid, fmax = scale_pts center = 0.0 if diverging else None - self._active_data["center"] = center - self._active_data["colormap"] = colormap - self._active_data["transparent"] = transparent + self._data["center"] = center + self._data["colormap"] = colormap + self._data["transparent"] = transparent self.update_lut(fmin=fmin, fmid=fmid, fmax=fmax) def _to_time_index(self, value): """Return the interpolated time index of the given time value.""" - time = self._active_data["time"] + time = self._data["time"] value = np.interp(value, time, np.arange(len(time))) return value @property def data(self): """Data used by time viewer and color bar widgets.""" - return self._active_data + return self._data @property - def _active_data(self): - return self._data[self._active_data_key] + def _data(self): + if self._active_data_key is None: + return None + return self._all_data.get(self._active_data_key) @property def labels(self): @@ -4059,7 +4059,7 @@ def _iter_time(self, time_idx, callback): ----- Used by movie and image sequence saving functions. """ - current_time_idx = self._active_data["time_idx"] + current_time_idx = self._data["time_idx"] for ii, idx in enumerate(time_idx): self.set_time_point(idx) if callback is not None: @@ -4242,7 +4242,7 @@ def _get_range(brain): multiplied by the scaling factor and when getting a value, this value should be divided by the scaling factor. """ - fmax = abs(brain._active_data["fmax"]) + fmax = abs(brain._data["fmax"]) if 1e-02 <= fmax <= 1e02: fscale_power = 0 else: diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 9f30af36955..c2f430af95d 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -93,7 +93,7 @@ def GetPickPosition(self): """Return the picked position.""" if self.hemi == "vol": self.point_id = self.cell_id - return self.brain._active_data["vol"]["grid_coords"][self.cell_id] + return self.brain._data["vol"]["grid_coords"][self.cell_id] else: vtk_cell = self.mesh.GetCell(self.cell_id) cell = [ @@ -400,7 +400,7 @@ def __init__(self): with pytest.raises(ValueError, match="must have shape"): lm.update_overlay(name="data", scalars=np.ones(1)) # remove_existing=False keeps the old overlay and adds a new one alongside - assert list(brain._data.keys()) == ["data"] + assert list(brain._all_data.keys()) == ["data"] assert "data" in lm._overlays brain.add_data( hemi_data, @@ -414,7 +414,7 @@ def __init__(self): key="overlay2", remove_existing=False, ) - assert "data" in brain._data and "overlay2" in brain._data + assert "data" in brain._all_data and "overlay2" in brain._all_data assert "data" in lm._overlays and "overlay2" in lm._overlays assert brain._active_data_key == "overlay2" brain.remove_data() @@ -994,7 +994,7 @@ def test_brain_time_viewer(renderer_interactive_pyvistaqt, pixel_ratio, brain_gc with use_log_level("debug"): brain.update_lut(fmin=12.0) - assert brain._active_data["fmin"] == 12.0 + assert brain._data["fmin"] == 12.0 brain.update_lut(fmax=4.0) _assert_brain_range(brain, [4.0, 4.0]) brain.update_lut(fmid=6.0) @@ -1195,8 +1195,8 @@ def test_brain_traces_vertex( for idx, current_hemi in enumerate(hemi_str): assert len(spheres) == 0 if current_hemi == "vol": - current_mesh = brain._active_data["vol"]["grid"] - vertices = brain._active_data["vol"]["vertices"] + current_mesh = brain._data["vol"]["grid"] + vertices = brain._data["vol"]["vertices"] values = current_mesh.point_data["values"][vertices] cell_id = vertices[np.argmax(np.abs(values))] else: @@ -1302,7 +1302,7 @@ def test_brain_traces_colormap(renderer_interactive_pyvistaqt, brain_gc): add_data_kwargs=dict(colorbar_kwargs=dict(n_labels=3)), ) # mne_analyze should be chosen - ctab = brain._active_data["ctable"] + ctab = brain._data["ctable"] assert_array_equal(ctab[0], [0, 255, 255, 255]) # opaque cyan assert_array_equal(ctab[-1], [255, 255, 0, 255]) # opaque yellow assert_allclose(ctab[len(ctab) // 2], [128, 128, 128, 0], atol=3) @@ -1509,7 +1509,7 @@ def test_brain_ui_events(renderer_interactive_pyvistaqt, brain_gc): kind="distributed_source_power", fmin=1, fmid=2, fmax=3, alpha=True ), ) - assert_array_equal(brain._active_data["ctable"][:3, 3], [0, 2, 4]) + assert_array_equal(brain._data["ctable"][:3, 3], [0, 2, 4]) # This event should be ignored. ui_events.publish( @@ -1519,7 +1519,7 @@ def test_brain_ui_events(renderer_interactive_pyvistaqt, brain_gc): ), ) # Should remain unchanged. - assert_array_equal(brain._active_data["ctable"][:3, 3], [0, 2, 4]) + assert_array_equal(brain._data["ctable"][:3, 3], [0, 2, 4]) brain.close() From 54978dedd065a99d545512966265c41ef4fcb656 Mon Sep 17 00:00:00 2001 From: payam Date: Fri, 3 Jul 2026 07:55:43 +0200 Subject: [PATCH 12/12] use _all_data as overlay store and now _data property returns active overlay --- mne/viz/_brain/_brain.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index e92e3602a63..f926fe51ead 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -844,7 +844,7 @@ def update_single_lut_value(value, key): hlayout = self._renderer._dock_add_layout(vertical=False) self.widgets[key] = self._renderer._dock_add_slider( name=None, - value=self._all_data[key] * self._data["fscale"], + value=self._data[key] * self._data["fscale"], rng=rng, callback=partial(update_single_lut_value, key=key), double=True, @@ -852,7 +852,7 @@ def update_single_lut_value(value, key): ) self.widgets[f"entry_{key}"] = self._renderer._dock_add_spin_box( name=None, - value=self._all_data[key] * self._data["fscale"], + value=self._data[key] * self._data["fscale"], callback=partial(update_single_lut_value, key=key), rng=rng, layout=hlayout, @@ -1308,7 +1308,7 @@ def _on_colormap_range(self, event): return lims = {key: getattr(event, key) for key in ("fmin", "fmid", "fmax", "alpha")} # Check if limits have changed at all. - if all(val is None or val == self._all_data[key] for key, val in lims.items()): + if all(val is None or val == self._data[key] for key, val in lims.items()): return # Update the GUI elements. with disable_ui_events(self):