Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions molmo_spaces/evaluation/eval_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def run_evaluation(
task_horizon=resolved_task_horizon,
num_workers=num_workers,
)
exp_config.task_type = "pick_and_place"

# Patch config with evaluation-specific runtime parameters
exp_config = JsonEvalRunner.patch_config(
Expand Down
29 changes: 21 additions & 8 deletions molmo_spaces/policy/learned_policy/pi_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

SAMPLE_PROMPTS = False

class PI_Policy(InferencePolicy):
def __init__(
Expand All @@ -21,11 +22,14 @@ def __init__(
) -> None:
super().__init__(exp_config, exp_config.task_type)
self.remote_config = exp_config.policy_config.remote_config
self.prompt_sampler = PromptSampler(
task_type=exp_config.task_type,
prompt_templates=exp_config.policy_config.prompt_templates,
prompt_object_word_num=exp_config.policy_config.prompt_object_word_num,
)
if SAMPLE_PROMPTS:
self.prompt_sampler = PromptSampler(
task_type=exp_config.task_type,
prompt_templates=exp_config.policy_config.prompt_templates,
prompt_object_word_num=exp_config.policy_config.prompt_object_word_num,
)
else:
self.prompt_sampler = None
self.checkpoint_path = exp_config.policy_config.checkpoint_path
self.grasping_type = exp_config.policy_config.grasping_type
self.chunk_size = exp_config.policy_config.chunk_size
Expand All @@ -35,7 +39,8 @@ def __init__(
def reset(self):
self.actions_buffer = None
self.current_buffer_index = 0
self.prompt_sampler.next()
if self.prompt_sampler is not None:
self.prompt_sampler.next()
self.starting_time = None

def prepare_model(self):
Expand Down Expand Up @@ -93,7 +98,10 @@ def render(self, obs):

def obs_to_model_input(self, obs):
# self.render(obs)
prompt = self.prompt_sampler.get_prompt(self.task).lower()
if self.prompt_sampler is None:
prompt = self.task.get_task_description()
else:
prompt = self.prompt_sampler.get_prompt(self.task).lower()

grip = np.clip(obs["qpos"]["gripper"][0] / 0.824033, 0, 1)
exo_camera_key = "droid_shoulder_light_randomization" if "droid_shoulder_light_randomization" in obs else "exo_camera_1"
Expand Down Expand Up @@ -151,7 +159,12 @@ def get_info(self) -> dict:
info["policy_buffer_length"] = self.chunk_size
info["policy_grasping_threshold"] = self.grasping_threshold
info["policy_grasping_type"] = self.grasping_type
info["prompt"] = self.prompt_sampler.get_prompt(self.task)
if self.prompt_sampler is not None:
info["prompt"] = self.prompt_sampler.get_prompt(self.task)
else:
info["prompt"] = self.task.get_task_description()
log.info(f"Current prompt: {info['prompt']}")

info["time_spent"] = time.time() - self.starting_time if self.starting_time else None
info["timestamp"] = time.time()
return info
1 change: 0 additions & 1 deletion molmo_spaces/policy/learned_policy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def get_prompt(self, task: BaseMujocoTask) -> str:
else:
self._cached_prompt = self.prompt_templates[self.current_index].format(object_name)

log.info(f"The prompt is: {self._cached_prompt}")
return self._cached_prompt

def clean_object_name(self, task: BaseMujocoTask) -> str:
Expand Down
10 changes: 9 additions & 1 deletion molmo_spaces/tasks/json_eval_task_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ def _validate_episode_spec(self, spec: EpisodeSpec) -> None:
f"Episode spec missing 'language.task_description'. house_index={spec.house_index}"
)

def _infer_task_type(self, spec: EpisodeSpec) -> str:
@staticmethod
def _infer_task_type(spec: EpisodeSpec) -> str:
"""Infer task_type from episode spec.

Uses task_type from spec if available, otherwise infers from task_cls.
Expand All @@ -282,6 +283,8 @@ def _infer_task_type(self, spec: EpisodeSpec) -> str:
return task_type

# Infer from task_cls
log.warning("Deprecated spec missing 'task.task_type', inferring from 'task.task_cls'. ")

task_cls = spec.get_task_cls()
task_cls_to_type = {
"molmo_spaces.tasks.pick_task.PickTask": "pick",
Expand All @@ -290,6 +293,11 @@ def _infer_task_type(self, spec: EpisodeSpec) -> str:
"molmo_spaces.tasks.opening_tasks.RBY1DoorOpeningTask": "door_opening",
"molmo_spaces.tasks.nav_task.NavToObjTask": "nav_to_obj",
}
task_cls_to_type_mujoco_thor = {}
for k,v in task_cls_to_type.items():
task_cls_to_type_mujoco_thor[k.replace("molmo_spaces", "mujoco_thor")] = v # TODO(rose): forking branch
task_cls_to_type.update(task_cls_to_type_mujoco_thor)

if task_cls in task_cls_to_type:
return task_cls_to_type[task_cls]

Expand Down