Skip to content

Add a new widget to show the units topological distribution on the probe #3142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ widgets = [
"matplotlib",
"ipympl",
"ipywidgets",
"seaborn>=0.13.0",
"sortingview>=0.12.0",
]

Expand Down
150 changes: 150 additions & 0 deletions src/spikeinterface/widgets/unit_spatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from __future__ import annotations

import numpy as np
from probeinterface import Probe
from warnings import warn
from .base import BaseWidget, to_attr


class UnitSpatialDistributionsWidget(BaseWidget):
"""
Placeholder documentation to be changed.

Parameters
----------
sorting_analyzer : SortingAnalyzer
The SortingAnalyzer object
depth_axis : int, default: 1
The dimension of unit_locations that is depth
"""

def __init__(
self,
sorting_analyzer,
probe=None,
depth_axis=1,
bins=None,
cmap="viridis",
kde=False,
depth_hist=True,
groups=None,
kde_kws=None,
backend=None,
**backend_kwargs,
):
sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

self.check_extensions(sorting_analyzer, "unit_locations")
ulc = sorting_analyzer.get_extension("unit_locations")
unit_locations = ulc.get_data(outputs="numpy")
x, y = unit_locations[:, 0], unit_locations[:, 1]

if type(probe) is Probe:
if sorting_analyzer.recording.has_probe():
warn(
"There is a Probe attached to this recording, but the probe argument is not None: the attached Probe will be ignored."
)
elif sorting_analyzer.recording.has_probe():
probe = sorting_analyzer.get_probe()
else:
raise ValueError(
"There is no Probe attached to this recording. Use set_probe(...) to attach one or pass it to the function via the probe argument."
)

# xrange, yrange, _ = get_auto_lims(probe, margin=0)
# if bins is None:
# bins = (
# np.round(np.diff(xrange).squeeze() / 75).astype(int),
# np.round(np.diff(yrange).squeeze() / 75).astype(int),
# )
# # TODO: change behaviour, if bins is not defined, bin only along the depth axis

plot_data = dict(
probe=probe,
x=x,
y=y,
depth_axis=depth_axis,
bins=bins,
kde=kde,
cmap=cmap,
depth_hist=depth_hist,
groups=groups,
kde_kws=kde_kws,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.patches as patches
import matplotlib.path as path
from probeinterface.plotting import get_auto_lims
from seaborn import color_palette, kdeplot, histplot
from .utils_matplotlib import make_mpl_figure

dp = to_attr(data_plot)
xrange, yrange, _ = get_auto_lims(dp.probe, margin=0)
cmap = color_palette(dp.cmap, as_cmap=True) if type(dp.cmap) is str else dp.cmap

self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

ax = self.ax

custom_shape = path.Path(dp.probe.probe_planar_contour)
patch = patches.PathPatch(custom_shape, facecolor="none", edgecolor="none")
ax.add_patch(patch)

if dp.kde is not True:
hist, xedges, yedges = np.histogram2d(dp.x, dp.y, bins=dp.bins, range=[xrange, yrange])
pcm = ax.pcolormesh(xedges, yedges, hist.T, cmap=cmap)
else:
kde_kws = dict(levels=100, thresh=0, fill=True, bw_adjust=0.1)
if dp.kde_kws is not None:
kde_kws.update(dp.kde_kws)
data = dict(x=dp.x, y=dp.y)
bg = ax.add_patch(
patches.Rectangle(
[xrange[0], yrange[0]],
np.diff(xrange).squeeze(),
np.diff(yrange).squeeze(),
facecolor=cmap.colors[0],
fill=True,
)
)
bg.set_clip_path(patch)
kdeplot(data, x="x", y="y", clip=[xrange, yrange], cmap=cmap, ax=ax, **kde_kws)
pcm = ax.collections[0]
ax.set_xlabel(None)
ax.set_ylabel(None)

pcm.set_clip_path(patch)

xlim, ylim, _ = get_auto_lims(dp.probe, margin=10)
ax.set_xlim(*xlim)
ax.set_ylim(*ylim)
ax.spines["top"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xticks([])
ax.set_xlabel("")
ax.set_ylabel("Depth (um)")

if dp.depth_hist is True:
bbox = ax.get_window_extent()
hist_height = 1.5 * bbox.width

ax_hist = ax.inset_axes([1, 0, hist_height / bbox.width, 1])
data = dict(y=dp.y)
data["group"] = np.ones(dp.y.size) if dp.groups is None else dp.groups
palette = color_palette("bright", n_colors=1 if dp.groups is None else np.unique(dp.groups).size)
histplot(
data=data,
y="y",
hue="group",
bins=dp.bins[1],
binrange=yrange,
palette=palette,
ax=ax_hist,
legend=False,
)
ax_hist.axis("off")
ax_hist.set_ylim(*ylim)
3 changes: 3 additions & 0 deletions src/spikeinterface/widgets/widget_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .unit_locations import UnitLocationsWidget
from .unit_presence import UnitPresenceWidget
from .unit_probe_map import UnitProbeMapWidget
from .unit_spatial import UnitSpatialDistributionsWidget
from .unit_summary import UnitSummaryWidget
from .unit_templates import UnitTemplatesWidget
from .unit_waveforms_density_map import UnitWaveformDensityMapWidget
Expand Down Expand Up @@ -67,6 +68,7 @@
UnitLocationsWidget,
UnitPresenceWidget,
UnitProbeMapWidget,
UnitSpatialDistributionsWidget,
UnitSummaryWidget,
UnitTemplatesWidget,
UnitWaveformDensityMapWidget,
Expand Down Expand Up @@ -142,6 +144,7 @@
plot_unit_locations = UnitLocationsWidget
plot_unit_presence = UnitPresenceWidget
plot_unit_probe_map = UnitProbeMapWidget
plot_unit_spatial_distribution = UnitSpatialDistributionsWidget
plot_unit_summary = UnitSummaryWidget
plot_unit_templates = UnitTemplatesWidget
plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget
Expand Down