diff --git a/cognitive/info_generator.py b/cognitive/info_generator.py index cdc8d17..bb5b579 100644 --- a/cognitive/info_generator.py +++ b/cognitive/info_generator.py @@ -247,14 +247,14 @@ def get_examples(self) -> Tuple[List, Dict, Dict]: } return task_info_dict, compo, memory_trace_info - def write_trial_instance(self, write_fp: str, img_size=224, fixation_cue=True) -> None: + def write_trial_instance(self, write_fp: str, img_size=224, fixation_cue=True, train = True) -> None: frames_fp = os.path.join(write_fp, 'frames') if os.path.exists(frames_fp): shutil.rmtree(frames_fp) os.makedirs(frames_fp) objset = self.frame_info.objset - for i, (epoch, frame) in enumerate(zip(sg.render(objset, img_size), task_info.frame_info)): + for i, (epoch, frame) in enumerate(zip(sg.render(objset, img_size, train = train), self.frame_info)): if fixation_cue: if not any('ending' in description for description in frame.description): sg.add_fixation_cue(epoch) diff --git a/cognitive/stim_generator.py b/cognitive/stim_generator.py index c081731..ae5eaf0 100644 --- a/cognitive/stim_generator.py +++ b/cognitive/stim_generator.py @@ -809,7 +809,7 @@ def render_static_obj(canvas, obj, img_size, train=True): canvas[x_offset:x_end, y_offset:y_end] = shape_net_obj -def render_obj(canvas, obj, img_size): +def render_obj(canvas, obj, img_size, train = True): """Render a single object. Args: @@ -819,9 +819,9 @@ def render_obj(canvas, obj, img_size): img_size: int, image size. """ if isinstance(obj, StaticObject): - render_static_obj(canvas, obj, img_size) + render_static_obj(canvas, obj, img_size, train = train) else: - render_static_obj(canvas, obj.to_static()[0], img_size) + render_static_obj(canvas, obj.to_static()[0], img_size, train = train) def render_static(objlists, img_size=224, save_name=None): @@ -879,7 +879,7 @@ def render_static(objlists, img_size=224, save_name=None): return movie -def render(objsets, img_size=224, save_name=None): +def render(objsets, img_size=224, save_name=None, train = True): """Render a movie by epoch. Args: @@ -907,7 +907,7 @@ def render(objsets, img_size=224, save_name=None): subset = objset.select_now(epoch_now) for obj in subset: - render_obj(canvas, obj, img_size) + render_obj(canvas, obj, img_size, train = train) i_frame += 1 if save_name is not None: diff --git a/main.py b/main.py index 30ec34f..5dd1fe0 100644 --- a/main.py +++ b/main.py @@ -181,7 +181,11 @@ def generate_dataset( validation_examples -= 1 fname = os.path.join(validation_fname, f'{i}') - info.write_trial_instance(fname, img_size, fixation_cue) + # xlei: pass train/val parameter to write trial_instances + if i < int(total_examples * train): + info.write_trial_instance(fname, img_size, fixation_cue, train = True) + else: + info.write_trial_instance(fname, img_size, fixation_cue, train=False) i += 1 else: i = 0