neuro_data/static_images/configs.py (34 changes: 29 additions & 5 deletions)

@@ -198,7 +198,9 @@ def get_loaders(self, datasets, tier, batch_size, stimulus_types, Sampler):
         return loaders
 
     def load_data(self, key, tier=None, batch_size=1, key_order=None,
-                  exclude_from_normalization=None, stimulus_types=None, Sampler=None):
+                  stimulus_types=None, Sampler=None, **kwargs):
+        stimulus_types = key.pop('stimulus_type')
+        exclude = key.pop('exclude').split(',')
         log.info('Loading {} dataset with tier={}'.format(
             self._stimulus_type, tier))
         datasets = StaticMultiDataset().fetch_data(key, key_order=key_order)
@@ -211,7 +213,7 @@ def load_data(self, key, tier=None, batch_size=1, key_order=None,
         log.info('Using statistics source ' + key['stats_source'])
 
         datasets = self.add_transforms(
-            key, datasets, exclude=exclude_from_normalization)
+            key, datasets, exclude=exclude)
 
         loaders = self.get_loaders(
             datasets, tier, batch_size, stimulus_types, Sampler)
@@ -221,10 +223,7 @@
 class AreaLayerRawMixin(StimulusTypeMixin):
     def load_data(self, key, tier=None, batch_size=1, key_order=None, stimulus_types=None, Sampler=None, **kwargs):
         log.info('Ignoring input arguments: "' +
                  '", "'.join(kwargs.keys()) + '"' + 'when creating datasets')
-        exclude = key.pop('exclude').split(',')
-        stimulus_types = key.pop('stimulus_type')
         datasets, loaders = super().load_data(key, tier, batch_size, key_order,
-                                              exclude_from_normalization=exclude,
                                               stimulus_types=stimulus_types,
                                               Sampler=Sampler)
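The net effect of these three hunks is that `StimulusTypeMixin.load_data` now pops `stimulus_type` and `exclude` out of the restriction key itself, so subclasses such as `AreaLayerRawMixin` no longer have to extract them before delegating. A minimal, self-contained sketch of the pattern (hypothetical class names and key values, not the repo's code):

# Sketch of the refactor: option extraction moves from the subclass into
# the base class, which pops the options off the key dict itself.
class Base:
    def load_data(self, key, **kwargs):
        stimulus_types = key.pop('stimulus_type')   # e.g. 'stimulus.Frame'
        exclude = key.pop('exclude').split(',')     # e.g. 'images,behavior'
        return stimulus_types, exclude

class Sub(Base):
    def load_data(self, key, **kwargs):
        # no key.pop(...) here anymore; the base class handles it
        return super().load_data(key, **kwargs)

print(Sub().load_data({'stimulus_type': 'stimulus.Frame', 'exclude': 'images,behavior'}))
# ('stimulus.Frame', ['images', 'behavior'])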

@@ -401,6 +400,31 @@ def load_data(self, key, cuda=False, oracle=False, **kwargs):

         return datasets, loaders
 
+    class StimulusType(dj.Part, StimulusTypeMixin):
+        definition = """ # stimulus type
+        -> master
+        ---
+        stats_source        : varchar(50)   # normalization source
+        stimulus_type       : varchar(50)   # type of stimulus
+        exclude             : varchar(512)  # what inputs to exclude from normalization
+        normalize           : bool          # whether to use a normalizer or not
+        normalize_per_image : bool          # whether to normalize each input separately
+        """
+
+        def describe(self, key):
+            return "Stimulus type {stimulus_type}. normalize={normalize} on {stats_source} (except '{exclude}')".format(
+                **key)
+
+        @property
+        def content(self):
+            for p in product(['all'],
+                             ['stimulus.Frame'],
+                             [''],
+                             [True],
+                             [False],
+                             ):
+                yield dict(zip(self.heading.secondary_attributes, p))
+
     class CorrectedAreaLayer(dj.Part, AreaLayerRawMixin):
         definition = """
         -> master
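For reference, the `content` generator in the new part table pairs each tuple from `itertools.product` with the table's secondary attributes. A standalone illustration, assuming the attribute names come back in definition order (the list below is written out by hand, not fetched from the heading):

from itertools import product

# assumed to match self.heading.secondary_attributes in definition order
secondary = ['stats_source', 'stimulus_type', 'exclude',
             'normalize', 'normalize_per_image']

for p in product(['all'], ['stimulus.Frame'], [''], [True], [False]):
    print(dict(zip(secondary, p)))
# {'stats_source': 'all', 'stimulus_type': 'stimulus.Frame', 'exclude': '',
#  'normalize': True, 'normalize_per_image': False}

Since every input list currently holds a single value, the generator yields exactly one default row; adding values to any list would enumerate the cross product.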
neuro_data/static_images/data_schemas.py (16 changes: 13 additions & 3 deletions)

@@ -96,7 +96,11 @@ def make(self, key):
         self.insert(StaticScanCandidate & key, ignore_extra_fields=True)
         pipe = (fuse.ScanDone() & key).fetch1('pipe')
         pipe = dj.create_virtual_module(pipe, 'pipeline_' + pipe)
-        units = (fuse.ScanDone * pipe.ScanSet.Unit * pipe.MaskClassification.Type & key
+        # check the segmentation compartment; if it is not soma, skip MaskClassification
+        if (shared.MaskType.proj(compartment='type') & (pipe.SegmentationTask & key)).fetch1('compartment') != 'soma':
+            units = (fuse.ScanDone * pipe.ScanSet.Unit & key & dict(pipe_version=1))
+        else:
+            units = (fuse.ScanDone * pipe.ScanSet.Unit * pipe.MaskClassification.Type & key
                  & dict(pipe_version=1, type='soma'))
         assert len(units) > 0, 'No units found!'
         self.Unit().insert(units,
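The gate above relies on DataJoint's `proj` to rename the `type` attribute to `compartment` before joining against `pipe.SegmentationTask` and fetching it. A rough plain-Python analogue of a rename-only projection (hypothetical helper; real `proj` also carries the primary key through):

def proj(rows, **renames):
    # keep only the renamed attributes: new_name=old_name, as in proj(compartment='type')
    return [{new: row[old] for new, old in renames.items()} for row in rows]

mask_types = [{'type': 'soma'}, {'type': 'axon'}]
print(proj(mask_types, compartment='type'))
# [{'compartment': 'soma'}, {'compartment': 'axon'}]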
@@ -528,10 +532,16 @@ def load_traces_and_frametimes(self, key):
         ndepth = len(dj.U('z') & (pipe.ScanInfo.Field() & k))
         frame_times = (stimulus.Sync() & key).fetch1('frame_times').squeeze()[::ndepth]
 
-        soma = pipe.MaskClassification.Type() & dict(type='soma')
+        # if the segmentation compartment is not soma, skip MaskClassification
+        mask_type = (Preprocessing & key).fetch1('mask_type')
+        if mask_type == 'all':
+            compartment_restrict = {}
+        else:
+            assert mask_type in shared.MaskType, f'mask_type {mask_type} not found in shared.MaskType'
+            compartment_restrict = pipe.MaskClassification.Type() & dict(type=mask_type)
 
         spikes = (dj.U('field', 'channel') * pipe.Activity.Trace() * StaticScan.Unit() \
-                  * pipe.ScanSet.UnitInfo() & soma & key)
+                  * pipe.ScanSet.UnitInfo() & compartment_restrict & key)
         traces, ms_delay, trace_keys = spikes.fetch('trace', 'ms_delay', dj.key,
                                                     order_by='animal_id, session, scan_idx, unit_id')
         delay = np.fromiter(ms_delay / 1000, dtype=np.float)
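The `compartment_restrict = {}` branch leans on DataJoint restriction semantics: restricting a relation by an empty mapping is a no-op, so the single `& compartment_restrict & key` expression covers both the 'all' case and a specific compartment. A plain-Python stand-in for that behavior (illustrative only, not DataJoint itself):

def restrict(rows, restriction):
    # an empty restriction matches every row, mirroring `rel & {}` in DataJoint
    return [r for r in rows if all(r.get(k) == v for k, v in restriction.items())]

rows = [{'unit_id': 1, 'type': 'soma'}, {'unit_id': 2, 'type': 'artifact'}]
print(restrict(rows, {}))                # both rows survive
print(restrict(rows, {'type': 'soma'}))  # only the soma unit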