diff --git a/neo/rawio/spikeglxrawio.py b/neo/rawio/spikeglxrawio.py index 9e0508de3..0e4297695 100644 --- a/neo/rawio/spikeglxrawio.py +++ b/neo/rawio/spikeglxrawio.py @@ -204,6 +204,7 @@ def _parse_header(self): # This is true only in case of 'nidq' stream for stream_name in stream_names: if "nidq" in stream_name: + #TODO: loop over all segments to add nidq events to _events_memmap info = self.signals_info_dict[0, stream_name] if len(info["digital_channels"]) > 0: # add event channels @@ -292,33 +293,113 @@ def _event_count(self, event_channel_idx, block_index=None, seg_index=None): return timestamps.size def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start=None, t_stop=None): + #TODO: fix seg_index usage, currently hardcoded for first segment timestamps, durations, labels = [], None, [] info = self.signals_info_dict[0, "nidq"] # There are no events that are not in the nidq stream dig_ch = info["digital_channels"] if len(dig_ch) > 0: event_data = self._events_memmap - channel = dig_ch[event_channel_index] + channel = dig_ch[event_channel_index] # 'XD0', 'XD1', etc. ch_idx = 7 - int(channel[2:]) # They are in the reverse order this_stream = event_data[:, ch_idx] - this_rising = np.where(np.diff(this_stream) == 1)[0] + 1 - this_falling = ( - np.where(np.diff(this_stream) == 255)[0] + 1 - ) # because the data is in unsigned 8 bit, -1 = 255! - if len(this_rising) > 0: - timestamps.extend(this_rising) - labels.extend([f"{channel} ON"] * len(this_rising)) - if len(this_falling) > 0: - timestamps.extend(this_falling) - labels.extend([f"{channel} OFF"] * len(this_falling)) + timestamps, durations, labels = self._find_events_in_channel(this_stream, channel) timestamps = np.asarray(timestamps) if len(labels) == 0: labels = np.asarray(labels, dtype="U1") else: labels = np.asarray(labels) + return timestamps, durations, labels + + def _get_sync_events(self, stream_index, seg_index = 0): + ''' + Find sync events in the stream. + + For imec streams, the sync events are found in the 6th bit + of the'SY0' channel, which should be the last 'analog' channel in the stream. + + For nidq streams, the sync events are found in the channel specified by the metadata fields + 'syncNiChanType' and 'syncNiChan'. + + Meta file descriptions taken from: + https://billkarsh.github.io/SpikeGLX/Sgl_help/Metadata_30.html + + Returns (timestamps, labels) + timestamps in samples of each edge + labels is a list of ('channel_name ON') or OFF for rising or falling edges + ''' + if stream_index > len(self.header["signal_streams"]): + raise ValueError("stream_index out of range") + + stream_name = self.header["signal_streams"][stream_index]["name"] + info = self.signals_info_dict[seg_index, stream_name] + + if 'imec' in stream_name: + if not self.load_sync_channel: + raise ValueError("SYNC channel was not loaded. Try setting load_sync_channel=True") + if not info["has_sync_trace"]: + raise ValueError("SYNC channel is not present in the recording." + " Cannot find sync events based on metadata field 'snsApLfSy'") + + #find sync events in the 'SY0' channel of imec streams + channel = 'SY0' + sync_data = self.get_analogsignal_chunk(channel_names = [channel], + stream_index = stream_index, + seg_index = seg_index) + #uint16 word to uint8 bytes to bits + sync_data_uint8 = sync_data.view(np.uint8) + unpacked_sync_data = np.unpackbits(sync_data_uint8, axis=1) + sync_line = unpacked_sync_data[:,1] + timestamps, _, labels = self._find_events_in_channel(sync_line, channel) + elif 'nidq' in stream_name: + #find channel from metafile + meta = info['meta'] + niChanType = int(meta['syncNiChanType']) + niChan = int(meta['syncNiChan']) + if niChanType == 0: #digital channel + timestamps, _, labels = self._get_event_timestamps(0, seg_index, niChan) + elif niChanType == 1: #analog channel + niThresh = float(meta['syncNiThresh']) #volts + sync_line = self.get_analogsignal_chunk(channel_names = [f'XA{niChan}'], + stream_index = stream_index, + seg_index = seg_index) + #Does this need to be scaled by channel gain before threshold? + sync_line = sync_line > niThresh + raise NotImplementedError("Analog sync events not yet implemented") + timestamps, _, labels = self._find_events_in_channel(sync_line, f'XA{niChan}') + else: + raise ValueError(f"Unknown stream type '{stream_name}', cannot find sync events") + return (timestamps, labels) - def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index): - info = self.signals_info_dict[0, "nidq"] # There are no events that are not in the nidq stream + def _find_events_in_channel(self, channel_data, channel_name): + ''' + Finds rising and falling edges in channel_data and returns + timestamps (in samples), duration (not implemented) and label with channel_name + ''' + + timestamps, durations, labels = [], None, [] + + this_rising = np.where(np.diff(channel_data) == 1)[0] + 1 + this_falling = ( + np.where(np.diff(channel_data) == 255)[0] + 1 + ) # because the data is in unsigned 8 bit, -1 = 255! + if len(this_rising) > 0: + timestamps.extend(this_rising) + labels.extend([f"{channel_name} ON"] * len(this_rising)) + if len(this_falling) > 0: + timestamps.extend(this_falling) + labels.extend([f"{channel_name} OFF"] * len(this_falling)) + timestamps = np.asarray(timestamps) + if len(labels) == 0: + labels = np.asarray(labels, dtype="U1") + else: + labels = np.asarray(labels) + return timestamps, durations, labels + + def _rescale_event_timestamp(self, event_timestamps, dtype=np.float64, event_channel_index=0, + stream_index = 0): + stream_name = self.header["signal_streams"][stream_index]["name"] + info = self.signals_info_dict[0, stream_name] # get sampling rate from first segment event_times = event_timestamps.astype(dtype) / float(info["sampling_rate"]) return event_times