Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions aopy/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import warnings
import yaml
import pandas as pd
import datetime

# importlib_resources is a backport of importlib.resources from Python 3.9
if sys.version_info >= (3,9):
Expand Down Expand Up @@ -49,6 +50,67 @@ def get_filenames_in_dir(base_dir, te):
files[system] = filename
return files

def get_te_number(file_name):
    '''
    Extracts the task entry (TE) number from a file name.

    Args:
        file_name (str): a single file name; expected to contain '_te<number>'
            immediately before the extension (e.g. 'beig20221002_09_te6890.hdf').
            A bare name without extension (e.g. 'chur20231002_02_te375') also works.

    Returns:
        int: TE number

    Raises:
        IndexError: if the file name contains no '_te' marker
        ValueError: if the text between '_te' and the extension is not an integer
    '''
    # Take everything after the last '_te', then strip the extension (if any)
    return int(file_name.split('_te')[1].split('.')[0])

def load_raw_hdf(data_dir, prefix, start_date_str):
    '''
    Retrieves files in a specified directory that share a common prefix (usually the
    name of the subject) and have been created since a specified start date. The files
    are then arranged based on their TE number, and the number of days between each
    consecutive file is calculated.

    Args:
        data_dir (str): directory where the files will be
        prefix (str): prefix of files to extract, usually the name of the monkey
            subject (i.e.: chur, beig, etc.)
        start_date_str (str): start date of files to extract, formatted as
            YYYYMMDD

    Returns:
        tuple: tuple containing:
            | **df (pd.DataFrame):** dataframe of file information, with columns
                'File Name' and 'Days Since Prev' (days since the previous file;
                0 for the first file)
            | **file_names_sorted (list):** file names sorted by TE number
    '''
    # Restrict to names like '<prefix>20YYMMDD...' so unrelated files are skipped
    file_names = [f for f in os.listdir(data_dir) if f.startswith(prefix+'20')]

    # Lexicographic comparison is valid here because dates are zero-padded YYYYMMDD
    file_names_filtered = []
    for name in file_names:
        date_str = name.split('_')[0].replace(prefix, '')
        if date_str >= start_date_str:
            file_names_filtered.append(name)

    file_names_sorted = sorted(file_names_filtered, key=get_te_number)

    print(f'{(len(file_names_sorted))} files parsed.')

    dates = [name.split('_')[0].replace(prefix, '') for name in file_names_sorted]
    # BUGFIX: this module does `import datetime`, so strptime lives on the
    # datetime.datetime class, not on the module — `datetime.strptime` raised
    # AttributeError
    dates = [datetime.datetime.strptime(d, '%Y%m%d').date() for d in dates]
    prev_date = None
    deltas = []

    for date in dates:
        if prev_date:
            delta = (date - prev_date).days
        else:
            # First file has no predecessor
            delta = 0
        deltas.append(delta)
        prev_date = date

    df = pd.DataFrame({'File Name': file_names_sorted, 'Days Since Prev': deltas})

    return df, file_names_sorted

def get_preprocessed_filename(subject, te_id, date, data_source):
'''
Generates preprocessed filenames as per our naming conventions.
Expand Down
44 changes: 43 additions & 1 deletion aopy/preproc/bmi3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,4 +723,46 @@ def get_target_events(exp_data, exp_metadata):
event_target = target_location[None,:] * target_on[:,None]
target_events.append(event_target)

return np.array(target_events).transpose(1,0,2)
return np.array(target_events).transpose(1,0,2)

def segment(start_events, end_events, data):
    '''
    Processes BMI3D events data and retrieves segments of events and corresponding timestamps
    based on specified start and end events.

    This function is similar to get_trial_segments_and_times(), except that it is more
    specialized to work with BMI3D events.

    Args:
        start_events (list): list of start events to identify the beginning of each segment
        end_events (list): list of end events to mark the end of each segment
        data (dict): BMI3D data; data['bmi3d_events'] must be indexable by
            'code' (event codes) and 'time' (event timestamps/indices)

    Returns:
        tuple: A tuple containing:
            | **segments (list of lists of events):** A list of segments, each represented as a list of events.
            | **segment_times (list of lists of times):** A list of corresponding timestamps for each event in the segments.
    '''
    bmi3d_events = data['bmi3d_events']

    event_code = bmi3d_events['code']
    event_inds = bmi3d_events['time']

    # np.isin replaces the deprecated np.in1d (removed in NumPy 2.0); plain set
    # membership is cheaper for the per-element checks in the scan loop below
    start_set = set(start_events)
    end_set = set(end_events)
    evt_start_idx = np.where(np.isin(event_code, start_events))[0]

    segments = []
    segment_times = []
    for idx_start in evt_start_idx:
        idx_end = idx_start + 1

        # Scan forward until an end event (complete segment), another start event
        # (abandoned segment — nothing is recorded), or the end of the data
        while idx_end < len(event_code):
            if event_code[idx_end] in start_set:
                break
            if event_code[idx_end] in end_set:
                segments.append(event_code[idx_start:idx_end+1])
                segment_times.append(event_inds[idx_start:idx_end+1])
                break
            idx_end += 1

    return segments, segment_times
70 changes: 70 additions & 0 deletions aopy/utils/tablet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
import re
import aopy
from aopy.preproc.bmi3d import segment

TRIAL_START = 2
REWARD = 48

def build_file_df(df, file_names_sorted, data_dir, traj = True):
    '''
    Builds a dataframe from given files with metadata (runtime, reward trial length,
    trajectory amplitude). Reports problematic files.

    Args:
        df (pd dataframe): dataframe with a row for each file to process
        file_names_sorted (list): list of file names, sorted by TE number
        data_dir (str): directory where files are found
        traj (bool): whether to include trajectory amplitude as a column

    Returns:
        pd dataframe: dataframe with processed metadata information for each given
        file, where each row is one file
    '''
    # Compile once instead of re-parsing the pattern on every file
    runtime_pattern = re.compile(r'"runtime":\s*([0-9.]+)')
    problem_flag = False

    print('Building dataset...')

    def _report_problem(file_name):
        # Print the header once, then each offending file name
        nonlocal problem_flag
        if not problem_flag:
            print('Problematic files: \n')
            problem_flag = True
        print(file_name)

    for i, f in enumerate(file_names_sorted):
        files = dict(hdf=f)
        data, metadata = aopy.preproc.bmi3d._parse_bmi3d_v1(data_dir, files)
        bmi3d_task = data['bmi3d_task']

        try:
            match = runtime_pattern.search(metadata['report'])
            runtime = round(float(match.group(1)))
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        # AttributeError: no regex match; KeyError: missing 'report';
        # TypeError: 'report' is not a string; ValueError: unparsable number
        except (AttributeError, KeyError, TypeError, ValueError):
            _report_problem(f)
            continue

        df.loc[i, 'Runtime'] = runtime

        if traj:
            try:
                z_targets = bmi3d_task['current_target'][:, 2]
                df.loc[i, 'Trajectory Amplitude'] = np.ceil(max(abs(min(z_targets)),
                                                                abs(max(z_targets))))
            except Exception:  # narrowed from a bare `except:`
                _report_problem(f)
                continue

        # Median length (in seconds) of rewarded trials, from TRIAL_START to REWARD
        rewarded = segment([TRIAL_START], [REWARD], data)[1]
        if rewarded:
            durations = [times[-1] - times[0] for times in rewarded]
            df.loc[i, 'Reward trial length'] = round(np.median(durations) / metadata['fps'])
        else:
            # BUGFIX: with no rewarded trials the old code did round(np.median([])),
            # i.e. round(nan), which raises ValueError — report the file instead
            _report_problem(f)

    return df

Binary file added tests/data/sample_hdfs/beig20210407_01_te1315.hdf
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/sample_hdfs/beig20210930_02_te2952.hdf
Binary file not shown.
Binary file added tests/data/sample_hdfs/beig20221002_09_te6890.hdf
Binary file not shown.
Binary file not shown.
Binary file added tests/data/sample_hdfs/chur20231002_02_te375.hdf
Binary file not shown.
14 changes: 14 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,20 @@ def test_chanel_bank_name(self):
self.assertEqual(ch_name, 'bottom')

class HDFTests(unittest.TestCase):

def test_get_te_number(self):
file_names = ['chur20231002_02_te375','beig20230109_15_te7977.hdf', 'beig20221002_09_te6890.hdf']
te_ids = [375, 7977, 6890]
result = [get_te_number(file_name) for file_name in file_names]
self.assertEqual(result, te_ids)

def test_load_raw_hdf(self):
hdf_dir = 'data/sample_hdfs'
os.path.join(data_dir, hdf_dir)
df, sorted_file_names = load_raw_hdf(hdf_dir, 'beig', '20220101')
correct_file_names = ['beig20221002_09_te6890.hdf', 'beig20230109_15_te7977.hdf']
self.assertEqual(sorted_file_names, correct_file_names)
self.assertEqual(list(df['Days Since Prev']), [0, 99])

def test_save_hdf(self):
import os
Expand Down