diff --git a/neuro_data/static_images/configs.py b/neuro_data/static_images/configs.py
index 2e63bbf..9e5a36f 100644
--- a/neuro_data/static_images/configs.py
+++ b/neuro_data/static_images/configs.py
@@ -198,7 +198,9 @@ def get_loaders(self, datasets, tier, batch_size, stimulus_types, Sampler):
         return loaders
 
     def load_data(self, key, tier=None, batch_size=1, key_order=None,
-                  exclude_from_normalization=None, stimulus_types=None, Sampler=None):
+                  stimulus_types=None, Sampler=None, **kwargs):
+        stimulus_types = key.pop('stimulus_type')
+        exclude = key.pop('exclude').split(',')
         log.info('Loading {} dataset with tier={}'.format(
             self._stimulus_type, tier))
         datasets = StaticMultiDataset().fetch_data(key, key_order=key_order)
@@ -211,7 +213,7 @@ def load_data(self, key, tier=None, batch_size=1, key_order=None,
             log.info('Using statistics source ' + key['stats_source'])
 
         datasets = self.add_transforms(
-            key, datasets, exclude=exclude_from_normalization)
+            key, datasets, exclude=exclude)
 
         loaders = self.get_loaders(
             datasets, tier, batch_size, stimulus_types, Sampler)
@@ -221,10 +223,7 @@ class AreaLayerRawMixin(StimulusTypeMixin):
     def load_data(self, key, tier=None, batch_size=1, key_order=None,
                   stimulus_types=None, Sampler=None, **kwargs):
         log.info('Ignoring input arguments: "' + '", "'.join(kwargs.keys()) + '"' + 'when creating datasets')
-        exclude = key.pop('exclude').split(',')
-        stimulus_types = key.pop('stimulus_type')
         datasets, loaders = super().load_data(key, tier, batch_size, key_order,
-                                              exclude_from_normalization=exclude,
                                               stimulus_types=stimulus_types,
                                               Sampler=Sampler)
 
@@ -401,6 +400,31 @@ def load_data(self, key, cuda=False, oracle=False, **kwargs):
 
         return datasets, loaders
 
+    class StimulusType(dj.Part, StimulusTypeMixin):
+        definition = """ # stimulus type
+        -> master
+        ---
+        stats_source        : varchar(50)  # normalization source
+        stimulus_type       : varchar(50)  # type of stimulus
+        exclude             : varchar(512) # what inputs to exclude from normalization
+        normalize           : bool         # whether to use a normalizer or not
+        normalize_per_image : bool         # whether to normalize each input separately
+        """
+
+        def describe(self, key):
+            return "Stimulus type {stimulus_type}. normalize={normalize} on {stats_source} (except '{exclude}')".format(
+                **key)
+
+        @property
+        def content(self):
+            for p in product(['all'],
+                             ['stimulus.Frame'],
+                             [''],
+                             [True],
+                             [False],
+                             ):
+                yield dict(zip(self.heading.secondary_attributes, p))
+
     class CorrectedAreaLayer(dj.Part, AreaLayerRawMixin):
         definition = """
         -> master
diff --git a/neuro_data/static_images/data_schemas.py b/neuro_data/static_images/data_schemas.py
index d3f378c..43fbd83 100644
--- a/neuro_data/static_images/data_schemas.py
+++ b/neuro_data/static_images/data_schemas.py
@@ -96,7 +96,11 @@ def make(self, key):
         self.insert(StaticScanCandidate & key, ignore_extra_fields=True)
         pipe = (fuse.ScanDone() & key).fetch1('pipe')
         pipe = dj.create_virtual_module(pipe, 'pipeline_' + pipe)
-        units = (fuse.ScanDone * pipe.ScanSet.Unit * pipe.MaskClassification.Type & key
+        # check segmentation compartment, if not soma, skip MaskClassification
+        if (shared.MaskType.proj(compartment='type') & (pipe.SegmentationTask & key)).fetch1('compartment') != 'soma':
+            units = (fuse.ScanDone * pipe.ScanSet.Unit & key & dict(pipe_version=1))
+        else:
+            units = (fuse.ScanDone * pipe.ScanSet.Unit * pipe.MaskClassification.Type & key
                  & dict(pipe_version=1, type='soma'))
         assert len(units) > 0, 'No units found!'
         self.Unit().insert(units,
@@ -528,10 +532,16 @@ def load_traces_and_frametimes(self, key):
         ndepth = len(dj.U('z') & (pipe.ScanInfo.Field() & k))
         frame_times = (stimulus.Sync() & key).fetch1('frame_times').squeeze()[::ndepth]
 
-        soma = pipe.MaskClassification.Type() & dict(type='soma')
+        # if segmentation compartment is not soma, skip maskclassification
+        mask_type = (Preprocessing & key).fetch1('mask_type')
+        if mask_type == 'all':
+            compartment_restrict = {}
+        else:
+            assert mask_type in shared.MaskType, f'mask_type {mask_type} not found in shared.MaskType'
+            compartment_restrict = pipe.MaskClassification.Type() & dict(type=mask_type)
         spikes = (dj.U('field', 'channel') * pipe.Activity.Trace() * StaticScan.Unit() \
-                  * pipe.ScanSet.UnitInfo() & soma & key)
+                  * pipe.ScanSet.UnitInfo() & compartment_restrict & key)
         traces, ms_delay, trace_keys = spikes.fetch('trace', 'ms_delay', dj.key,
                                                     order_by='animal_id, session, scan_idx, unit_id')
         delay = np.fromiter(ms_delay / 1000, dtype=np.float)
 