diff --git a/neo/io/nixio_fr.py b/neo/io/nixio_fr.py
index fe2f6ee8c..2a7d6d8cd 100644
--- a/neo/io/nixio_fr.py
+++ b/neo/io/nixio_fr.py
@@ -1,5 +1,6 @@
 from neo.io.basefromrawio import BaseFromRaw
 from neo.rawio.nixrawio import NIXRawIO
+import warnings


 # This class subjects to limitations when there are multiple asymmetric blocks
@@ -11,10 +12,29 @@ class NixIO(NIXRawIO, BaseFromRaw):
     _prefered_signal_group_mode = 'group-by-same-units'
     _prefered_units_group_mode = 'all-in-one'

-    def __init__(self, filename):
-        NIXRawIO.__init__(self, filename)
+    def __init__(self, filename, block_index=0, autogenerate_stream_names=False, autogenerate_unit_ids=False):
+        NIXRawIO.__init__(self, filename,
+                          block_index=block_index,
+                          autogenerate_stream_names=autogenerate_stream_names,
+                          autogenerate_unit_ids=autogenerate_unit_ids)
         BaseFromRaw.__init__(self, filename)

+    def read_block(self, block_index=0, **kwargs):
+        # sanity check to ensure constructed header and block to load match
+        if block_index != 0:
+            raise ValueError(f'This IO was initialized for block {self.block_index} of the file '
+                             f'and can only read that block (as block_index=0). Instantiate a '
+                             f'new IO with block_index={block_index} to read another block.')
+
+        return super(NixIO, self).read_block(block_index=0, **kwargs)
+
+    def read_segment(self, block_index=0, **kwargs):
+        # sanity check to ensure constructed header and block to load match
+        if block_index != 0:
+            raise ValueError(f'This IO was initialized for block {self.block_index} of the file '
+                             f'and can only read that block (as block_index=0). Instantiate a '
+                             f'new IO with block_index={block_index} to read another block.')
+
+        return super(NixIO, self).read_segment(block_index=0, **kwargs)
+
     def __enter__(self):
         return self
diff --git a/neo/rawio/nixrawio.py b/neo/rawio/nixrawio.py
index 21f47cb24..7ce00a7fe 100644
--- a/neo/rawio/nixrawio.py
+++ b/neo/rawio/nixrawio.py
@@ -6,17 +6,15 @@
 Author: Chek Yin Choi, Julia Sprenger
 """

-
-import os.path
+import warnings

 import numpy as np
+from packaging.version import Version

 from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
                         _spike_channel_dtype, _event_channel_dtype)

 from ..io.nixio import check_nix_version

-
-
 # When reading metadata properties, the following keys are ignored since they
 # are used to store Neo object properties.
 # This dictionary is used in the _filter_properties() method.
@@ -34,10 +32,54 @@ class NIXRawIO(BaseRawIO):

     extensions = ['nix', 'h5']
     rawmode = 'one-file'

-    def __init__(self, filename=''):
+    def __init__(self, filename='', block_index=0, autogenerate_stream_names=False, autogenerate_unit_ids=False):
+        if autogenerate_stream_names:
+            warnings.warn('Automatically generating streams based on signal order in files. '
+                          'Potentially overwriting stream names.')
+        if autogenerate_unit_ids:
+            warnings.warn('Automatically generating unit_ids. Ignoring stored unit information '
+                          'and using order of spiketrains instead. '
+                          'Check for correct unit assignment.')
         check_nix_version()
         BaseRawIO.__init__(self)
+
+        # checking consistency of generating neo version and autogeneration settings for reading
+        import nixio
+        nix_file = nixio.File.open(str(filename), nixio.FileMode.ReadOnly)
+        neo_generation_version = Version(nix_file.sections['neo'].props['version'].values[0])
+        nix_file.close()
+
+        if neo_generation_version < Version('0.7.0') and not autogenerate_stream_names:
+            warnings.warn('Can load nix files generated by neo<0.7.0 only by autogenerating '
+                          'stream names. Overwriting user setting '
+                          '`autogenerate_stream_names=False`')
+            autogenerate_stream_names = True
+
+        if neo_generation_version < Version('0.10.0') and not autogenerate_unit_ids:
+            warnings.warn('Can load nix files generated by neo<0.10.0 only by autogenerating '
+                          'unit ids. Overwriting user setting `autogenerate_unit_ids=False`')
+            autogenerate_unit_ids = True
+
+        self.filename = str(filename)
+        self.autogenerate_stream_names = autogenerate_stream_names
+        self.autogenerate_unit_ids = autogenerate_unit_ids
+        self.block_index = block_index
+
+    @staticmethod
+    def get_block_count(filename):
+        """
+        Retrieve the number of Blocks present in the nix file.
+
+        Returns:
+            (int) The number of blocks in the file.
+        """
+        import nixio
+        nix_file = nixio.File.open(filename, nixio.FileMode.ReadOnly)
+        block_count = len(nix_file.blocks)
+        nix_file.close()
+
+        return block_count

     def _source_name(self):
         return self.filename
@@ -46,182 +88,266 @@ def _parse_header(self):
         import nixio
         self.file = nixio.File.open(self.filename, nixio.FileMode.ReadOnly)

-        signal_channels = []
-        anasig_ids = {0: []}  # ids of analogsignals by segment
-        stream_ids = []
-        for bl in self.file.blocks:
-            for seg in bl.groups:
-                for da_idx, da in enumerate(seg.data_arrays):
-                    if da.type == "neo.analogsignal":
-                        chan_id = da_idx
-                        ch_name = da.metadata['neo_name']
-                        units = str(da.unit)
-                        dtype = str(da.dtype)
-                        sr = 1 / da.dimensions[0].sampling_interval
-                        anasig_id = da.name.split('.')[-2]
-                        if anasig_id not in anasig_ids[0]:
-                            anasig_ids[0].append(anasig_id)
-                        stream_id = anasig_ids[0].index(anasig_id)
-                        if stream_id not in stream_ids:
-                            stream_ids.append(stream_id)
-                        gain = 1
-                        offset = 0.
-                        signal_channels.append((ch_name, chan_id, sr, dtype,
-                                                units, gain, offset, stream_id))
-                # only read structure of first segment and assume the same
-                # across segments
-                break
-            break
-        signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
-        signal_streams = np.zeros(len(stream_ids), dtype=_signal_stream_dtype)
-        signal_streams['id'] = stream_ids
-        signal_streams['name'] = ''
-
-        spike_channels = []
-        unit_name = ""
-        unit_id = ""
-        for bl in self.file.blocks:
-            seg_groups = [g for g in bl.groups if g.type == "neo.segment"]
-
-            for seg in seg_groups:
-                for mt in seg.multi_tags:
-                    if mt.type == "neo.spiketrain":
-                        unit_name = mt.metadata['neo_name']
-                        unit_id = mt.id
-                        wf_left_sweep = 0
-                        wf_units = None
-                        wf_sampling_rate = 0
-                        if mt.features:
-                            wf = mt.features[0].data
-                            wf_units = wf.unit
-                            dim = wf.dimensions[2]
-                            interval = dim.sampling_interval
-                            wf_sampling_rate = 1 / interval
-                            if wf.metadata:
-                                wf_left_sweep = wf.metadata["left_sweep"]
-                        wf_gain = 1
-                        wf_offset = 0.
-                        spike_channels.append(
-                            (unit_name, unit_id, wf_units, wf_gain,
-                             wf_offset, wf_left_sweep, wf_sampling_rate)
-                        )
-                break
-            break
-        spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
+        stream_name_by_id = {}
+
+        self.nix_block = self.file.blocks[self.block_index]
+        segment_groups = [g for g in self.nix_block.groups if g.type == "neo.segment"]
+        neo_group_groups = [g for g in self.nix_block.groups if g.type == "neo.group"]
+
+        def assert_channel_consistency(channels_by_segment):
+            reference_channels = np.asarray(channels_by_segment[0])
+            if not self.autogenerate_stream_names:
+                # include name fields in stream comparison
+                name_mask = list(reference_channels.dtype.fields.keys())
+            else:
+                # remove name fields from comparison
+                name_mask = [k for k in reference_channels.dtype.fields.keys() if k != 'name']
+
+            msg = 'Inconsistency across Segments: Try loading another block or use the ' \
+                  'neo.io.nixio.NixIO for loading the nix file.'
+
+            for segment_channels in channels_by_segment[1:]:
+                # compare channel numbers
+                if len(segment_channels) != len(reference_channels):
+                    raise ValueError(msg + f' Inconsistent number of channels.\n'
+                                           f'{len(segment_channels)} != '
+                                           f'{len(reference_channels)}')

+                # compare channel details
+                if not np.array_equal(segment_channels[name_mask], reference_channels[name_mask]):
+                    raise ValueError(msg + f' Channel specifications are inconsistent:\n'
+                                           f'{segment_channels} differs from '
+                                           f'{reference_channels}')
+
+        def data_array_to_signal_channel(chan_id, stream_id, da):
+            assert da.type == "neo.analogsignal"
+            ch_name = da.metadata['neo_name']
+            units = str(da.unit)
+            dtype = str(da.dtype)
+            # TODO: The sampling_interval unit is not taken into account for reading...
+            sr = 1 / da.dimensions[0].sampling_interval
+            gain = 1
+            offset = 0.
+            return (ch_name, chan_id, sr, dtype, units, gain, offset, stream_id)
+
+        # construct signal channels for all segments
+        segments_signal_channels = []
+        for seg_idx, seg in enumerate(segment_groups):
+            stream_id = -1
+            last_anasig_id = 0
+            signal_channels = []
+            for da_idx, da in enumerate(seg.data_arrays):
+                if da.type == "neo.analogsignal":
+                    # identify stream_id by common anasig id
+                    anasig_id = da.name.split('.')[-2]
+                    stream_name = da.metadata.props['neo_name'].values[0]
+                    if anasig_id != last_anasig_id:
+                        stream_id += 1
+                        last_anasig_id = anasig_id
+                        if stream_id not in stream_name_by_id:
+                            stream_name_by_id[stream_id] = stream_name
+
+                    # sanity check for stream_id <=> stream name association
+                    if not self.autogenerate_stream_names and stream_name_by_id[stream_id] != stream_name:
+                        raise ValueError('Stream inconsistency across Segments or Blocks: '
+                                         'Try loading individual blocks or use the '
+                                         'neo.io.nixio.NixIO for loading the nix file.\n'
+                                         f'{stream_id=} with {stream_name=} does not match '
+                                         f'{stream_name_by_id=}.')
+
+                    channel = data_array_to_signal_channel(da_idx, stream_id, da)
+                    signal_channels.append(channel)
+
+            signal_channels = np.asarray(signal_channels, dtype=_signal_channel_dtype)
+            segments_signal_channels.append(signal_channels)
+
+        # verify consistency across segments
+        assert_channel_consistency(segments_signal_channels)
+
+        signal_streams = np.zeros(len(stream_name_by_id), dtype=_signal_stream_dtype)
+        signal_streams['id'] = list(stream_name_by_id.keys())
+        signal_streams['name'] = list(stream_name_by_id.values())
+
+        def multi_tag_to_spike_channel(mt, unit_id):
+            assert mt.type == "neo.spiketrain"
+            unit_name = mt.metadata['neo_name']
+            wf_left_sweep = 0
+            wf_units = None
+            wf_sampling_rate = 0
+            if mt.features:
+                wf = mt.features[0].data
+                wf_units = wf.unit
+                dim = wf.dimensions[2]
+                interval = dim.sampling_interval
+                wf_sampling_rate = 1 / interval
+                if wf.metadata:
+                    wf_left_sweep = wf.metadata["left_sweep"]
+            wf_gain = 1
+            wf_offset = 0.
+            return (unit_name, unit_id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate)
+
+        segments_spike_channels = []
+
+        # detect neo groups that can be used to group spiketrains across segments
+        neo_spiketrain_groups = []
+        for group in neo_group_groups:
+            # assume a group represents a spiketrain `unit` when it contains only spiketrains
+            # and exactly as many spiketrains as there are segments.
+            # Is there a better way to check for this?
+            if (group.type == "neo.group" and
+                    all([mt.type == "neo.spiketrain" for mt in group.multi_tags]) and
+                    len(group.multi_tags) == len(segment_groups)):
+                neo_spiketrain_groups.append(group)
+
+        self.unit_list_by_segment = []
+
+        for seg in segment_groups:
+            default_unit_id = 0
+            spike_channels = []
+
+            st_dict = {'spiketrains': [],
+                       'spiketrains_unit_id': [],
+                       'waveforms': []}
+            self.unit_list_by_segment.append(st_dict)
+
+            for mt in seg.multi_tags:
+                if mt.type != "neo.spiketrain":
+                    continue
+
+                if self.autogenerate_unit_ids:
+                    unit_id = default_unit_id
-        event_channels = []
-        event_count = 0
-        epoch_count = 0
-        for bl in self.file.blocks:
-            seg_groups = [g for g in bl.groups if g.type == "neo.segment"]
-            for seg in seg_groups:
-                for mt in seg.multi_tags:
-                    if mt.type == "neo.event":
-                        ev_name = mt.metadata['neo_name']
-                        ev_id = event_count
-                        event_count += 1
-                        ev_type = "event"
-                        event_channels.append((ev_name, ev_id, ev_type))
-                    if mt.type == "neo.epoch":
-                        ep_name = mt.metadata['neo_name']
-                        ep_id = epoch_count
-                        epoch_count += 1
-                        ep_type = "epoch"
-                        event_channels.append((ep_name, ep_id, ep_type))
-                break
-            break
-        event_channels = np.array(event_channels, dtype=_event_channel_dtype)
-
-        self.da_list = {'blocks': []}
-        for block_index, blk in enumerate(self.file.blocks):
-            seg_groups = [g for g in blk.groups if g.type == "neo.segment"]
-            d = {'segments': []}
-            self.da_list['blocks'].append(d)
-            for seg_index, seg in enumerate(seg_groups):
-                d = {'signals': []}
-                self.da_list['blocks'][block_index]['segments'].append(d)
-                size_list = []
-                data_list = []
-                da_name_list = []
-                for da in seg.data_arrays:
-                    if da.type == 'neo.analogsignal':
-                        size_list.append(da.size)
-                        data_list.append(da)
-                        da_name_list.append(da.metadata['neo_name'])
-                block = self.da_list['blocks'][block_index]
-                segment = block['segments'][seg_index]
-                segment['data_size'] = size_list
-                segment['data'] = data_list
-                segment['ch_name'] = da_name_list
-
-        self.unit_list = {'blocks': []}
-        for block_index, blk in enumerate(self.file.blocks):
-            seg_groups = [g for g in blk.groups if g.type == "neo.segment"]
-            d = {'segments': []}
-            self.unit_list['blocks'].append(d)
-            for seg_index, seg in enumerate(seg_groups):
-                d = {'spiketrains': [],
-                     'spiketrains_id': [],
-                     'spiketrains_unit': []}
-                self.unit_list['blocks'][block_index]['segments'].append(d)
-                st_idx = 0
-                for st in seg.multi_tags:
-                    d = {'waveforms': []}
-                    block = self.unit_list['blocks'][block_index]
-                    segment = block['segments'][seg_index]
-                    segment['spiketrains_unit'].append(d)
-                    if st.type == 'neo.spiketrain':
-                        segment['spiketrains'].append(st.positions)
-                        segment['spiketrains_id'].append(st.id)
-                        wftypestr = "neo.waveforms"
-                        if (st.features and st.features[0].data.type == wftypestr):
-                            waveforms = st.features[0].data
-                            stdict = segment['spiketrains_unit'][st_idx]
-                            if waveforms:
-                                stdict['waveforms'] = waveforms
-                            else:
-                                stdict['waveforms'] = None
-                        # assume one spiketrain one waveform
-                        st_idx += 1
+                else:
+                    # files generated with neo < 0.10.0: extract or define the unit id of the
+                    # spiketrain from nix sources
+                    nix_generation_neo_version = Version(self.file.sections['neo'].props['version'].values[0])
+                    if nix_generation_neo_version < Version('0.10.0'):
+                        unit_sources = [s for s in mt.sources if s.type == 'neo.unit']
+
+                        if len(unit_sources) == 1:
+                            unit_id = unit_sources[0].name
+
+                        elif len(unit_sources) == 0:
+                            warnings.warn('No unit information found. Using default unit id.')
+                            unit_id = default_unit_id
+
+                        elif len(unit_sources) > 1:
+                            raise ValueError('Ambiguous unit assignment detected. '
+                                             'Use `autogenerate_unit_ids=True` to ignore '
+                                             'unit_ids in nix file and regenerate new ids.')
+
+                    # files generated with recent neo versions use groups to group spiketrains
+                    elif nix_generation_neo_version >= Version('0.10.0'):
+                        unit_groups = [g for g in neo_spiketrain_groups if mt in g.multi_tags]
+
+                        if len(unit_groups) == 1:
+                            unit_id = unit_groups[0].metadata.props['neo_name'].values[0]
+
+                        elif len(unit_groups) == 0:
+                            warnings.warn('No unit information found. Using default unit id.')
+                            unit_id = default_unit_id
+
+                        elif len(unit_groups) > 1:
+                            raise ValueError('Ambiguous unit assignment detected. '
+                                             'Use `autogenerate_unit_ids=True` to ignore '
+                                             'unit_ids in nix file and regenerate new ids.')
+
+                # register spiketrain data for faster data retrieval later on
+                st_dict['spiketrains'].append(mt.positions)
+                st_dict['spiketrains_unit_id'].append(unit_id)
+                waveforms = None
+                if mt.features and mt.features[0].data.type == "neo.waveforms":
+                    if mt.features[0].data:
+                        waveforms = mt.features[0].data
+                st_dict['waveforms'].append(waveforms)
+
+                spike_channels.append(multi_tag_to_spike_channel(mt, unit_id))
+                default_unit_id += 1
+
+            spike_channels = np.asarray(spike_channels, dtype=_spike_channel_dtype)
+            segments_spike_channels.append(spike_channels)
+
+        # verify consistency across segments
+        assert_channel_consistency(segments_spike_channels)
+
+        segments_event_channels = []
+        for seg in segment_groups:
+            event_count = 0
+            epoch_count = 0
+            event_channels = []
+            for mt in seg.multi_tags:
+                if mt.type == "neo.event":
+                    ev_name = mt.metadata['neo_name']
+                    ev_id = event_count
+                    event_count += 1
+                    ev_type = "event"
+                    event_channels.append((ev_name, ev_id, ev_type))
+                if mt.type == "neo.epoch":
+                    ep_name = mt.metadata['neo_name']
+                    ep_id = epoch_count
+                    epoch_count += 1
+                    ep_type = "epoch"
+                    event_channels.append((ep_name, ep_id, ep_type))
+            event_channels = np.asarray(event_channels, dtype=_event_channel_dtype)
+            segments_event_channels.append(event_channels)
+
+        assert_channel_consistency(segments_event_channels)
+
+        # precollecting data array information
+        self.da_list_by_segments = []
+        for seg_index, seg in enumerate(segment_groups):
+            da_dict = {}
+            self.da_list_by_segments.append(da_dict)
+            size_list = []
+            data_list = []
+            da_name_list = []
+            for da in seg.data_arrays:
+                if da.type == 'neo.analogsignal':
+                    size_list.append(da.size)
+                    data_list.append(da)
+                    da_name_list.append(da.metadata['neo_name'])
+            da_dict['data_size'] = size_list
+            da_dict['data'] = data_list
+            da_dict['ch_name'] = da_name_list

         self.header = {}
-        self.header['nb_block'] = len(self.file.blocks)
-        self.header['nb_segment'] = [
-            len(seg_groups)
-            for bl in self.file.blocks
-        ]
+        self.header['nb_block'] = 1
+        self.header['nb_segment'] = [len(segment_groups)]
         self.header['signal_streams'] = signal_streams
-        self.header['signal_channels'] = signal_channels
-        self.header['spike_channels'] = spike_channels
-        self.header['event_channels'] = event_channels
+        # use signal, spike and event channels of the first segment,
+        # as these are consistent across segments
+        self.header['signal_channels'] = segments_signal_channels[0]
+        self.header['spike_channels'] = segments_spike_channels[0]
+        self.header['event_channels'] = segments_event_channels[0]

         self._generate_minimal_annotations()
-        for blk_idx, blk in enumerate(self.file.blocks):
-            seg_groups = [g for g in blk.groups if g.type == "neo.segment"]
-            bl_ann = self.raw_annotations['blocks'][blk_idx]
-            props = blk.metadata.inherited_properties()
-            bl_ann.update(self._filter_properties(props, "block"))
-            for grp_idx, group in enumerate(seg_groups):
-                seg_ann = bl_ann['segments'][grp_idx]
-                props = group.metadata.inherited_properties()
-                seg_ann.update(self._filter_properties(props, "segment"))
-
-                sp_idx = 0
-                ev_idx = 0
-                for mt in group.multi_tags:
-                    if mt.type == 'neo.spiketrain' and seg_ann['spikes']:
-                        st_ann = seg_ann['spikes'][sp_idx]
+        bl_ann = self.raw_annotations['blocks'][0]
+        props = self.nix_block.metadata.inherited_properties()
+        bl_ann.update(self._filter_properties(props, "block"))
+        for grp_idx, group in enumerate(segment_groups):
+            seg_ann = bl_ann['segments'][grp_idx]
+            props = group.metadata.inherited_properties()
+            seg_ann.update(self._filter_properties(props, "segment"))
+
+            sp_idx = 0
+            ev_idx = 0
+            for mt in group.multi_tags:
+                if mt.type == 'neo.spiketrain' and seg_ann['spikes']:
+                    st_ann = seg_ann['spikes'][sp_idx]
+                    props = mt.metadata.inherited_properties()
+                    st_ann.update(self._filter_properties(props, 'spiketrain'))
+                    sp_idx += 1
+                # if order is preserved, the annotations
+                # should go to the right place; needs testing
+                if mt.type == "neo.event" or mt.type == "neo.epoch":
+                    if seg_ann['events'] != []:
+                        event_ann = seg_ann['events'][ev_idx]
                         props = mt.metadata.inherited_properties()
-                        st_ann.update(self._filter_properties(props, 'spiketrain'))
-                        sp_idx += 1
-                    # if order is preserving, the annotations
-                    # should go to the right place, need test
-                    if mt.type == "neo.event" or mt.type == "neo.epoch":
-                        if seg_ann['events'] != []:
-                            event_ann = seg_ann['events'][ev_idx]
-                            props = mt.metadata.inherited_properties()
-                            event_ann.update(self._filter_properties(props, 'event'))
-                            ev_idx += 1
+                        event_ann.update(self._filter_properties(props, 'event'))
+                        ev_idx += 1

         # adding array annotations to analogsignals
         annotated_anasigs = []
@@ -257,14 +383,14 @@ def _parse_header(self):

     def _segment_t_start(self, block_index, seg_index):
         t_start = 0
-        for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
+        for mt in self.nix_block.groups[seg_index].multi_tags:
             if mt.type == "neo.spiketrain":
                 t_start = mt.metadata['t_start']
         return t_start

     def _segment_t_stop(self, block_index, seg_index):
         t_stop = 0
-        for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
+        for mt in self.nix_block.groups[seg_index].multi_tags:
             if mt.type == "neo.spiketrain":
                 t_stop = mt.metadata['t_stop']
         return t_stop

@@ -274,9 +400,7 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
         keep = self.header['signal_channels']['stream_id'] == stream_id
         channel_indexes, = np.nonzero(keep)
         ch_idx = channel_indexes[0]
-        block = self.da_list['blocks'][block_index]
-        segment = block['segments'][seg_index]
-        size = segment['data_size'][ch_idx]
+        size = self.da_list_by_segments[seg_index]['data_size'][ch_idx]
         return size  # size is per signal, not the sum of all channel_indexes

@@ -284,8 +408,7 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):
         keep = self.header['signal_channels']['stream_id'] == stream_id
         channel_indexes, = np.nonzero(keep)
         ch_idx = channel_indexes[0]
-        block = self.file.blocks[block_index]
-        das = [da for da in block.groups[seg_index].data_arrays]
+        das = [da for da in self.nix_block.groups[seg_index].data_arrays]
         da = das[ch_idx]
         sig_t_start = float(da.metadata['t_start'])
         return sig_t_start  # assume same group_id always same t_start

@@ -304,7 +427,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
             i_stop = self.get_signal_size(block_index, seg_index, stream_index)

         raw_signals_list = []
-        da_list = self.da_list['blocks'][block_index]['segments'][seg_index]
+        da_list = self.da_list_by_segments[seg_index]
         for idx in global_channel_indexes:
             da = da_list['data'][idx]
             raw_signals_list.append(da[i_start:i_stop])

@@ -316,7 +439,7 @@ def _spike_count(self, block_index, seg_index, unit_index):
         count = 0
         head_id = self.header['spike_channels'][unit_index][1]
-        for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
+        for mt in self.nix_block.groups[seg_index].multi_tags:
             for src in mt.sources:
                 if mt.type == 'neo.spiketrain' and [src.type == "neo.unit"]:
                     if head_id == src.id:

@@ -325,8 +448,7 @@ def _get_spike_timestamps(self, block_index, seg_index,
                               unit_index, t_start, t_stop):
-        block = self.unit_list['blocks'][block_index]
-        segment = block['segments'][seg_index]
+        segment = self.unit_list_by_segment[seg_index]
         spike_dict = segment['spiketrains']
         spike_timestamps = spike_dict[unit_index]
         spike_timestamps = np.transpose(spike_timestamps)

@@ -345,8 +467,8 @@ def _rescale_spike_timestamp(self, spike_timestamps, dtype):

     def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index,
                                  t_start, t_stop):
         # this must return a 3D numpy array (nb_spike, nb_channel, nb_sample)
-        seg = self.unit_list['blocks'][block_index]['segments'][seg_index]
-        waveforms = seg['spiketrains_unit'][unit_index]['waveforms']
+        seg = self.unit_list_by_segment[seg_index]
+        waveforms = seg['waveforms'][unit_index]
         if not waveforms:
             return None
         raw_waveforms = np.array(waveforms)

@@ -364,7 +486,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index,

     def _event_count(self, block_index, seg_index, event_channel_index):
         event_count = 0
-        segment = self.file.blocks[block_index].groups[seg_index]
+        segment = self.nix_block.groups[seg_index]
         for event in segment.multi_tags:
             if event.type == 'neo.event' or event.type == 'neo.epoch':
                 if event_count == event_channel_index:

@@ -377,10 +499,11 @@ def _get_event_timestamps(self, block_index, seg_index,
                               event_channel_index, t_start, t_stop):
         timestamp = []
         labels = []
-        durations = None
+        durations = []
         if event_channel_index is None:
             raise IndexError
-        for mt in self.file.blocks[block_index].groups[seg_index].multi_tags:
+        segments = [g for g in self.nix_block.groups if g.type == 'neo.segment']
+        for mt in segments[seg_index].multi_tags:
             if mt.type == "neo.event" or mt.type == "neo.epoch":
                 labels.append(mt.positions.dimensions[0].labels)
                 po = mt.positions
@@ -389,24 +512,30 @@
                 channel = self.header['event_channels'][event_channel_index]
                 if channel['type'] == b'epoch' and mt.extents:
                     if mt.extents.type == 'neo.epoch.durations':
-                        durations = np.array(mt.extents)
-                        break
+                        durations.append(np.array(mt.extents))
+                else:
+                    durations.append(None)
         timestamp = timestamp[event_channel_index][:]
         timestamp = np.array(timestamp, dtype="float")
+        durations = durations[event_channel_index]
         labels = labels[event_channel_index][:]
         labels = np.array(labels, dtype='U')
         if t_start is not None:
             keep = timestamp >= t_start
             timestamp, labels = timestamp[keep], labels[keep]
+            if durations is not None:
+                durations = durations[keep]
         if t_stop is not None:
             keep = timestamp <= t_stop
             timestamp, labels = timestamp[keep], labels[keep]
+            if durations is not None:
+                durations = durations[keep]
         return timestamp, durations, labels

     # only the first fits in rescale
     def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
         ev_unit = ''
-        for mt in self.file.blocks[0].groups[0].multi_tags:
+        for mt in self.nix_block.groups[0].multi_tags:
             if mt.type == "neo.event":
                 ev_unit = mt.positions.unit
                 break

@@ -418,7 +547,7 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index)

     def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
         ep_unit = ''
-        for mt in self.file.blocks[0].groups[0].multi_tags:
+        for mt in self.nix_block.groups[0].multi_tags:
             if mt.type == "neo.epoch":
                 ep_unit = mt.positions.unit
                 break
diff --git a/neo/test/iotest/test_nixio_fr.py b/neo/test/iotest/test_nixio_fr.py
index d57b18ae2..86960266f 100644
--- a/neo/test/iotest/test_nixio_fr.py
+++ b/neo/test/iotest/test_nixio_fr.py
@@ -1,14 +1,17 @@
 """
 Tests of neo.io.nixio_fr
 """
-import numpy as np
 import unittest
-from quantities import s
-from neo.io.nixio_fr import NixIO as NixIOfr
+
+import numpy as np
 import quantities as pq
+from quantities import s
+
+from neo.core import Block, Segment, AnalogSignal, SpikeTrain, Event
 from neo.io.nixio import NixIO
+from neo.io.nixio_fr import NixIO as NixIOfr
 from neo.test.iotest.common_io_test import BaseTestIO
-from neo.core import Block, Segment, AnalogSignal, SpikeTrain, Event
+from neo.test.tools import assert_same_sub_schema

 try:
     import nixio as nix
@@ -24,18 +27,22 @@ class TestNixfr(BaseTestIO, unittest.TestCase, ):
     ioclass = NixIOfr

     entities_to_download = [
-        'nix/nixio_fr.nix'
+        'nix/'
     ]
     entities_to_test = [
-        'nix/nixio_fr.nix'
+        # for BaseIO tests use a rawio-compatible file that does not require special flags
+        # to be set for loading
+        'nix/nix_rawio_compatible.nix'
     ]

     def setUp(self):
         super().setUp()
+
         self.testfilename = self.get_local_path('nix/nixio_fr.nix')
-        self.reader_fr = NixIOfr(filename=self.testfilename)
+        self.reader_fr = NixIOfr(filename=self.testfilename, autogenerate_stream_names=True,
+                                 block_index=1)
         self.reader_norm = NixIO(filename=self.testfilename, mode='ro')
-        self.blk = self.reader_fr.read_block(block_index=1, load_waveforms=True)
+        self.blk = self.reader_fr.read_block(load_waveforms=True)  # read block with NixIOfr
         self.blk1 = self.reader_norm.read_block(index=1)  # read same block with NixIO

@@ -105,17 +112,17 @@ def test_annotations(self):
         bl = Block(**annotations)
         annotations = {'something': 'hello hello000'}
         seg = Segment(**annotations)
-        an =AnalogSignal([[1, 2, 3], [4, 5, 6]], units='V',
-                          sampling_rate=1 * pq.Hz)
+        an = AnalogSignal([[1, 2, 3], [4, 5, 6]], units='V',
+                          sampling_rate=1 * pq.Hz)
         an.annotate(ansigrandom='hello chars')
         an.array_annotate(custom_id=[1, 2, 3])
-        sp = SpikeTrain([3, 4, 5]* s, t_stop=10.0)
+        sp = SpikeTrain([3, 4, 5] * s, t_stop=10.0)
         sp.annotations['railway'] = 'hello train'
-        ev = Event(np.arange(0, 30, 10)*pq.Hz,
+        ev = Event(np.arange(0, 30, 10) * pq.Hz,
                    labels=np.array(['trig0', 'trig1', 'trig2'], dtype='U'))
         ev.annotations['venue'] = 'hello event'
         ev2 = Event(np.arange(0, 30, 10) * pq.Hz,
-                   labels=np.array(['trig0', 'trig1', 'trig2'], dtype='U'))
+                    labels=np.array(['trig0', 'trig1', 'trig2'], dtype='U'))
         ev2.annotations['evven'] = 'hello ev'
         seg.spiketrains.append(sp)
         seg.events.append(ev)
@@ -140,6 +147,34 @@
         os.remove(self.testfilename)


+@unittest.skipUnless(HAVE_NIX, "Requires NIX")
+class CompareTestFileVersions(BaseTestIO, unittest.TestCase):
+    ioclass = NixIOfr
+    entities_to_download = ['nix']
+    entities_to_test = []
+
+    @classmethod
+    def setUpClass(cls):
+        super(CompareTestFileVersions, cls).setUpClass()
+
+        cls.neo_versions = ['0.6.1', '0.7.2', '0.8.0', '0.9.0', '0.10.2', '0.11.1', '0.12.0']
+        cls.blocks = []
+
+        for filename in [f'nix/generated_file_neo{ver}.nix' for ver in cls.neo_versions]:
+            filename = BaseTestIO.get_local_path(filename)
+            print(f'Loading {filename}')
+
+            io = NixIOfr(filename, autogenerate_stream_names=False, autogenerate_unit_ids=False)
+            block = io.read_block(lazy=False)
+            cls.blocks.append(block)
+            io.file.close()
+
+    def test_compare_file_versions(self):
+        # assert all versions result in comparable neo structures (ideally identical)
+        reference_block = self.blocks[0]
+        for bl in self.blocks[1:]:
+            assert_same_sub_schema(reference_block, bl, exclude=['file_origin', 'magnitude'])
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/neo/test/rawiotest/rawio_compliance.py b/neo/test/rawiotest/rawio_compliance.py
index a40cb5d27..9db3678bb 100644
--- a/neo/test/rawiotest/rawio_compliance.py
+++ b/neo/test/rawiotest/rawio_compliance.py
@@ -223,18 +223,18 @@ def read_analogsignals(reader):

             # read 500ms with several chunksize
             sr = reader.get_signal_sampling_rate(stream_index=stream_index)
-            lenght_to_read = int(.5 * sr)
-            if lenght_to_read < sig_size:
+            length_to_read = int(.5 * sr)
+            if length_to_read and length_to_read < sig_size:
                 ref_raw_sigs = reader.get_analogsignal_chunk(block_index=block_index,
                                                              seg_index=seg_index,
                                                              i_start=0,
-                                                             i_stop=lenght_to_read,
+                                                             i_stop=length_to_read,
                                                              stream_index=stream_index,
                                                              channel_indexes=channel_indexes)

                 for chunksize in (511, 512, 513, 1023, 1024, 1025):
                     i_start = 0
                     chunks = []
-                    while i_start < lenght_to_read:
-                        i_stop = min(i_start + chunksize, lenght_to_read)
+                    while i_start < length_to_read:
+                        i_stop = min(i_start + chunksize, length_to_read)
                         raw_chunk = reader.get_analogsignal_chunk(block_index=block_index,
                                                                   seg_index=seg_index,
                                                                   i_start=i_start, i_stop=i_stop,
diff --git a/neo/test/rawiotest/test_nixrawio.py b/neo/test/rawiotest/test_nixrawio.py
index 9c57fa2d1..1bf495067 100644
--- a/neo/test/rawiotest/test_nixrawio.py
+++ b/neo/test/rawiotest/test_nixrawio.py
@@ -3,17 +3,18 @@
 from neo.test.rawiotest.common_rawio_test import BaseTestRawIO


-testfname = ""
-
-
 class TestNixRawIO(BaseTestRawIO, unittest.TestCase):
     rawioclass = NIXRawIO
     entities_to_download = [
-        'nix/nixrawio-1.5.nix'
+        'nix'
     ]
     entities_to_test = [
         'nix/nixrawio-1.5.nix'
     ]
+
+    neo_versions = ['0.6.1', '0.7.2', '0.8.0', '0.9.0', '0.10.2', '0.11.1', '0.12.0']
+    neo_version_testfiles = [f'nix/generated_file_neo{ver}.nix' for ver in neo_versions]
+    entities_to_test.extend(neo_version_testfiles)


 if __name__ == "__main__":
     unittest.main()
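
Usage note for reviewers: a minimal sketch of the multi-block reading pattern this changeset introduces, using only APIs from the patch (NIXRawIO.get_block_count and the new NixIO constructor arguments). The filename is a placeholder, and the snippet is illustrative rather than part of the change; for files written by neo < 0.10.0 the constructor warns and enables the autogenerate flags on its own.

    from neo.io.nixio_fr import NixIO
    from neo.rawio.nixrawio import NIXRawIO

    filename = 'recording.nix'  # placeholder: any neo-generated nix file

    # One NixIO instance now maps to exactly one nix block, so query the
    # block count first and create one reader per block; each reader then
    # exposes its block as block_index=0 via read_block().
    blocks = []
    for block_index in range(NIXRawIO.get_block_count(filename)):
        io = NixIO(filename, block_index=block_index)
        blocks.append(io.read_block(load_waveforms=True))
        io.file.close()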