diff --git a/aopy/data/base.py b/aopy/data/base.py
index 61497a6e..5c87677d 100644
--- a/aopy/data/base.py
+++ b/aopy/data/base.py
@@ -14,6 +14,7 @@
 import warnings
 import yaml
 import pandas as pd
+import datetime
 
 # importlib_resources is a backport of importlib.resources from Python 3.9
 if sys.version_info >= (3,9):
@@ -49,6 +50,67 @@ def get_filenames_in_dir(base_dir, te):
             files[system] = filename
     return files
 
+def get_te_number(file_name):
+    '''
+    Extracts the TE number from a file name.
+
+    Args:
+        file_name (str): a single file name
+
+    Returns:
+        int: TE number
+
+    '''
+    return int(file_name.split('_te')[1].split('.')[0])
+
+def load_raw_hdf(data_dir, prefix, start_date_str):
+    '''
+    Retrieves files in a specified directory that share a common prefix (usually the
+    name of the subject) and were created on or after a specified start date. The files
+    are sorted by their TE number, and the number of days between each pair of
+    consecutive files is calculated.
+
+    Args:
+        data_dir (str): directory containing the files
+        prefix (str): prefix of files to extract, usually the name of the monkey
+            subject (e.g. chur, beig)
+        start_date_str (str): start date of files to extract, formatted as
+            YYYYMMDD
+
+    Returns:
+        tuple: tuple containing:
+            | **df (pd dataframe):** file information, including file names and days since the previous file
+            | **file_names_sorted (list):** sorted file names in the provided directory
+    '''
+    file_names = [f for f in os.listdir(data_dir) if f.startswith(prefix+'20')]
+
+    file_names_filtered = []
+    for name in file_names:
+        date_str = name.split('_')[0].replace(prefix, '')
+        if date_str >= start_date_str:
+            file_names_filtered.append(name)
+
+    file_names_sorted = sorted(file_names_filtered, key=get_te_number)
+
+    print(f'{len(file_names_sorted)} files parsed.')
+
+    dates = [name.split('_')[0].replace(prefix, '') for name in file_names_sorted]
+    dates = [datetime.datetime.strptime(d, '%Y%m%d').date() for d in dates]
+    prev_date = None
+    deltas = []
+
+    for date in dates:
+        if prev_date:
+            delta = (date - prev_date).days
+        else:
+            delta = 0
+        deltas.append(delta)
+        prev_date = date
+
+    df = pd.DataFrame({'File Name': file_names_sorted, 'Days Since Prev': deltas})
+
+    return df, file_names_sorted
+
 def get_preprocessed_filename(subject, te_id, date, data_source):
     '''
     Generates preprocessed filenames as per our naming conventions.
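For context, a minimal usage sketch of the two helpers added above. It is not part of the patch; the directory path is a placeholder, and importing directly from aopy.data.base simply mirrors where this diff defines the functions:

    # Sketch only: '/path/to/raw_hdfs' is a hypothetical directory of raw BMI3D hdf files.
    from aopy.data.base import get_te_number, load_raw_hdf

    print(get_te_number('beig20230109_15_te7977.hdf'))   # 7977

    # List 'beig' recordings created on or after 2022-01-01, sorted by TE number.
    df, sorted_names = load_raw_hdf('/path/to/raw_hdfs', 'beig', '20220101')
    print(df[['File Name', 'Days Since Prev']])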
diff --git a/aopy/preproc/bmi3d.py b/aopy/preproc/bmi3d.py
index d77257ee..822fe05c 100644
--- a/aopy/preproc/bmi3d.py
+++ b/aopy/preproc/bmi3d.py
@@ -723,4 +723,46 @@ def get_target_events(exp_data, exp_metadata):
         event_target = target_location[None,:] * target_on[:,None]
         target_events.append(event_target)
 
-    return np.array(target_events).transpose(1,0,2)
\ No newline at end of file
+    return np.array(target_events).transpose(1,0,2)
+
+def segment(start_events, end_events, data):
+    '''
+    Processes BMI3D event data and retrieves segments of events and corresponding timestamps
+    based on specified start and end events.
+
+    This function is similar to get_trial_segments_and_times(), except that it is more
+    specialized to work with BMI3D events.
+
+    Args:
+        start_events (list): list of start events that identify the beginning of each segment
+        end_events (list): list of end events that mark the end of each segment
+        data (dict): dictionary of parsed BMI3D data; its 'bmi3d_events' entry must contain 'code' and 'time' fields
+
+    Returns:
+        tuple: A tuple containing:
+            | **segments (list of lists of events):** a list of segments, each represented as a list of event codes
+            | **segment_times (list of lists of times):** the corresponding timestamps for each event in the segments
+    '''
+    bmi3d_events = data['bmi3d_events']
+
+    event_code = bmi3d_events['code']
+    event_inds = bmi3d_events['time']
+
+    evt_start_idx = np.where(np.in1d(event_code, start_events))[0]
+
+    segments = []
+    segment_times = []
+    for idx_evt in range(len(evt_start_idx)):
+        idx_start = evt_start_idx[idx_evt]
+        idx_end = evt_start_idx[idx_evt] + 1
+
+        while idx_end < len(event_code):
+            if np.in1d(event_code[idx_end], start_events):
+                break
+            if np.in1d(event_code[idx_end], end_events):
+                segments.append(event_code[idx_start:idx_end+1])
+                segment_times.append(event_inds[idx_start:idx_end+1])
+                break
+            idx_end += 1
+
+    return segments, segment_times
diff --git a/aopy/utils/tablet.py b/aopy/utils/tablet.py
new file mode 100644
index 00000000..8c5ed5bb
--- /dev/null
+++ b/aopy/utils/tablet.py
@@ -0,0 +1,64 @@
+import numpy as np
+import re
+import aopy
+from aopy.preproc.bmi3d import segment
+
+TRIAL_START = 2
+REWARD = 48
+
+def build_file_df(df, file_names_sorted, data_dir, traj=True):
+    '''
+    Builds a dataframe from the given files with metadata (runtime, reward trial length,
+    trajectory amplitude). Reports problematic files.
+
+    Args:
+        df (pd dataframe): dataframe with a row for each file to process
+        file_names_sorted (list): list of file names, sorted by TE number
+        data_dir (str): directory where the files are found
+        traj (bool): whether to include trajectory amplitude as a column
+
+    Returns:
+        pd dataframe: dataframe with processed metadata for each given file,
+            where each row is one file
+    '''
+    problem_flag = False
+
+    print('Building dataset...')
+
+    for i, f in enumerate(file_names_sorted):
+        files = dict(hdf=f)
+        data, metadata = aopy.preproc.bmi3d._parse_bmi3d_v1(data_dir, files)
+        bmi3d_task = data['bmi3d_task']
+
+        # runtime (in seconds), parsed from the experiment report
+        pattern = r'"runtime":\s*([0-9.]+)'
+        try:
+            match = re.search(pattern, metadata['report'])
+            runtime = round(float(match.group(1)))
+        except Exception:
+            if not problem_flag:
+                print('Problematic files:')
+                problem_flag = True
+            print(f)
+            continue
+
+        df.loc[i, 'Runtime'] = runtime
+
+        # trajectory amplitude: largest absolute target displacement along the z axis
+        if traj:
+            try:
+                df.loc[i, 'Trajectory Amplitude'] = np.ceil(max(abs(min(bmi3d_task['current_target'][:, 2])),
+                                                                abs(max(bmi3d_task['current_target'][:, 2]))))
+            except Exception:
+                if not problem_flag:
+                    print('Problematic files:')
+                    problem_flag = True
+                print(f)
+                continue
+
+        # median rewarded trial length (in seconds)
+        rewarded = segment([TRIAL_START], [REWARD], data)[1]
+        trial_times = [t[-1] - t[0] for t in rewarded]
+        df.loc[i, 'Reward trial length'] = round(np.median(trial_times) / metadata['fps'])
+
+    return df
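A sketch of how these pieces might be chained together. The data directory is a placeholder; the call to _parse_bmi3d_v1() simply mirrors the one build_file_df() itself makes in this diff:

    # Sketch only: data_dir is a hypothetical directory of raw BMI3D hdf files.
    import aopy
    from aopy.data.base import load_raw_hdf
    from aopy.utils.tablet import build_file_df, TRIAL_START, REWARD
    from aopy.preproc.bmi3d import segment

    data_dir = '/path/to/raw_hdfs'
    df, names = load_raw_hdf(data_dir, 'beig', '20220101')

    # Add per-file runtime, trajectory amplitude, and median reward trial length.
    df = build_file_df(df, names, data_dir)

    # Or, for a single parsed recording, pull out the rewarded trial segments directly.
    data, metadata = aopy.preproc.bmi3d._parse_bmi3d_v1(data_dir, dict(hdf=names[0]))
    segments, segment_times = segment([TRIAL_START], [REWARD], data)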
diff --git a/tests/data/sample_hdfs/beig20210407_01_te1315.hdf b/tests/data/sample_hdfs/beig20210407_01_te1315.hdf
new file mode 100644
index 00000000..80531d72
Binary files /dev/null and b/tests/data/sample_hdfs/beig20210407_01_te1315.hdf differ
diff --git a/tests/data/sample_hdfs/beig20210614_07_te1825.hdf b/tests/data/sample_hdfs/beig20210614_07_te1825.hdf
new file mode 100644
index 00000000..aba672e9
Binary files /dev/null and b/tests/data/sample_hdfs/beig20210614_07_te1825.hdf differ
diff --git a/tests/data/sample_hdfs/beig20210929_02_te2949.hdf b/tests/data/sample_hdfs/beig20210929_02_te2949.hdf
new file mode 100644
index 00000000..a23663f5
Binary files /dev/null and b/tests/data/sample_hdfs/beig20210929_02_te2949.hdf differ
diff --git a/tests/data/sample_hdfs/beig20210930_02_te2952.hdf b/tests/data/sample_hdfs/beig20210930_02_te2952.hdf
new file mode 100644
index 00000000..c861d555
Binary files /dev/null and b/tests/data/sample_hdfs/beig20210930_02_te2952.hdf differ
diff --git a/tests/data/sample_hdfs/beig20221002_09_te6890.hdf b/tests/data/sample_hdfs/beig20221002_09_te6890.hdf
new file mode 100644
index 00000000..b4f68cc8
Binary files /dev/null and b/tests/data/sample_hdfs/beig20221002_09_te6890.hdf differ
diff --git a/tests/data/sample_hdfs/beig20230109_15_te7977.hdf b/tests/data/sample_hdfs/beig20230109_15_te7977.hdf
new file mode 100644
index 00000000..3e9eb62b
Binary files /dev/null and b/tests/data/sample_hdfs/beig20230109_15_te7977.hdf differ
diff --git a/tests/data/sample_hdfs/chur20231002_02_te375.hdf b/tests/data/sample_hdfs/chur20231002_02_te375.hdf
new file mode 100644
index 00000000..dd9c53df
Binary files /dev/null and b/tests/data/sample_hdfs/chur20231002_02_te375.hdf differ
diff --git a/tests/test_data.py b/tests/test_data.py
index f06a0908..939066b9 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -292,6 +292,20 @@ def test_chanel_bank_name(self):
         self.assertEqual(ch_name, 'bottom')
 
 class HDFTests(unittest.TestCase):
+
+    def test_get_te_number(self):
+        file_names = ['chur20231002_02_te375', 'beig20230109_15_te7977.hdf', 'beig20221002_09_te6890.hdf']
+        te_ids = [375, 7977, 6890]
+        result = [get_te_number(file_name) for file_name in file_names]
+        self.assertEqual(result, te_ids)
+
+    def test_load_raw_hdf(self):
+        hdf_dir = 'sample_hdfs'
+        hdf_dir = os.path.join(data_dir, hdf_dir)
+        df, sorted_file_names = load_raw_hdf(hdf_dir, 'beig', '20220101')
+        correct_file_names = ['beig20221002_09_te6890.hdf', 'beig20230109_15_te7977.hdf']
+        self.assertEqual(sorted_file_names, correct_file_names)
+        self.assertEqual(list(df['Days Since Prev']), [0, 99])
 
     def test_save_hdf(self):
         import os
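As a quick sanity check of the 99-day expectation in test_load_raw_hdf, the gap between the two matching sample files can be reproduced with the standard library:

    from datetime import date

    # beig20221002_09_te6890.hdf -> beig20230109_15_te7977.hdf
    print((date(2023, 1, 9) - date(2022, 10, 2)).days)   # 99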