diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index a7f40a9558..84bab9447a 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -1,9 +1,12 @@ from __future__ import annotations from pathlib import Path +import random from typing import Union from packaging import version +from spikeinterface.core.generate import _ensure_seed + from ..basesorter import BaseSorter from .kilosortbase import KilosortBase @@ -57,6 +60,7 @@ class Kilosort4Sorter(BaseSorter): "skip_kilosort_preprocessing": False, "scaleproc": None, "torch_device": "auto", + "seed": None, } _params_description = { @@ -99,6 +103,7 @@ class Kilosort4Sorter(BaseSorter): "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing", "scaleproc": "int16 scaling of whitened data, if None set to 200.", "torch_device": "Select the torch device auto/cuda/cpu", + "seed": "Kilosort random seed", } sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching. @@ -244,9 +249,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels())) ops["Nbatches"] = bfile.n_batches - np.random.seed(1) - torch.cuda.manual_seed_all(1) - torch.random.manual_seed(1) + params["seed"] = _ensure_seed(params["seed"]) + random.seed(params["seed"]) + np.random.seed(params["seed"]) + torch.cuda.manual_seed_all(params["seed"]) + torch.random.manual_seed(params["seed"]) # if not params["skip_kilosort_preprocessing"]: if not params["do_correction"]: print("Skipping drift correction.") diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index cf6933c9e6..aea5056933 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -1,12 +1,14 @@ from __future__ import annotations from pathlib import Path +import random from packaging.version import parse import shutil import numpy as np from warnings import warn +from spikeinterface.core.generate import _ensure_seed from spikeinterface.preprocessing import bandpass_filter, whiten from spikeinterface.core.baserecording import BaseRecording @@ -45,6 +47,7 @@ class Mountainsort5Sorter(BaseSorter): "filter": True, "whiten": True, # Important to do whitening "delete_temporary_recording": True, + "seed": None, } _params_description = { @@ -69,6 +72,7 @@ class Mountainsort5Sorter(BaseSorter): "filter": "Enable or disable filter", "whiten": "Enable or disable whitening", "delete_temporary_recording": "If True, the temporary recording file is deleted after sorting (this may fail on Windows requiring the end-user to delete the file themselves later)", + "seed": "Random seed", } sorter_description = "MountainSort5 uses Isosplit clustering. It is an updated version of MountainSort4. See https://doi.org/10.1016/j.neuron.2017.08.030" @@ -184,6 +188,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"] ) + params["seed"] = _ensure_seed(params["seed"]) + random.seed(params["seed"]) + np.random.seed(params["seed"]) + if not recording.is_binary_compatible(): recording_cached = recording.save(folder=sorter_output_folder / "recording", **get_job_kwargs(p, verbose)) else: