diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f2bccc7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Andres Garcia + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md index 1100467..b417b87 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ Explore synthetic dataset generation using formal grammars: - **`custom_grammar_gen.py`** - Generate datasets using custom grammar configurations - **`filtered_dataset_generation.py`** - Advanced dataset generation with simulation-based filtering - **`mixed_gen.py`** - Mixed dataset generation combining multiple approaches +- **`rlhf_dataset_gen.py`** - Interactive RLHF-inspired dataset generation with human-in-the-loop feedback using a unified GUI #### **Experimental Setup Examples** (`examples/04_experimental_setup_example/`) Complete experimental pipeline example for research and evaluation: diff --git a/agent_control/simulation/agents/agent_sensors.py b/agent_control/simulation/agents/agent_sensors.py index 8036784..987a759 100644 --- a/agent_control/simulation/agents/agent_sensors.py +++ b/agent_control/simulation/agents/agent_sensors.py @@ -25,7 +25,7 @@ def sense_light(self) -> Vector2: Returns: Vector2: Direction vector to the nearest light source """ - light_pos = self.agent.env.light_pos + light_pos = self.agent.env.light_pos # type: ignore[attr-defined] agent_pos = self.agent.pos diff_vec: Vector2 = light_pos - agent_pos diff --git a/agent_control/simulation/agents/elements.py b/agent_control/simulation/agents/elements.py index 252cc6d..ba9c199 100644 --- a/agent_control/simulation/agents/elements.py +++ b/agent_control/simulation/agents/elements.py @@ -11,7 +11,7 @@ from ..envs.robot_env import RobotEnvironment -class Part(Agent): # type: ignore +class Part(Agent): """Part object that can be picked up and moved by robot agents.""" def __init__( @@ -38,7 +38,7 @@ def __init__( self.simulation = simulation self.type = type self.pos = pos - self.owner: Optional[Agent] = None + self.owner: Optional[Agent | "RobotEnvironment"] = None self.env = env self.is_permanently_placed = False self.update_img() @@ -50,11 +50,11 @@ def update_img(self) -> 
None: def remove_part(self) -> None: """Remove this part from the simulation.""" - self.kill() + self.kill() # type: ignore[no-untyped-call] def update(self) -> None: """Update the part's position and state.""" - if self.owner and not self.is_permanently_placed: + if self.owner and not self.is_permanently_placed and hasattr(self.owner, "pos"): self.pos = self.owner.pos def can_be_picked_up(self) -> bool: diff --git a/agent_control/simulation/agents/robot_agent.py b/agent_control/simulation/agents/robot_agent.py index b69db5e..467855a 100644 --- a/agent_control/simulation/agents/robot_agent.py +++ b/agent_control/simulation/agents/robot_agent.py @@ -5,7 +5,7 @@ import pygame as pg from pygame.math import Vector2 -from vi import Agent, Simulation +from vi import Agent, HeadlessSimulation, Simulation from tree_parser.middle_parser import parse_behavior_tree @@ -16,13 +16,13 @@ from ..envs.base_env import SimEnvironment -class RobotAgent(Agent): # type: ignore +class RobotAgent(Agent): """Robot agent that executes behavior trees in simulation environments.""" def __init__( self, images: List[pg.Surface], - simulation: Simulation, + simulation: Simulation | HeadlessSimulation, pos: Vector2, env: "SimEnvironment", xml_path: str, diff --git a/agent_control/simulation/envs/robot_env.py b/agent_control/simulation/envs/robot_env.py index ec06f28..fae9e6d 100644 --- a/agent_control/simulation/envs/robot_env.py +++ b/agent_control/simulation/envs/robot_env.py @@ -88,32 +88,32 @@ def draw_arena(self) -> None: """Draw arena boundaries in the simulation environment.""" self.simulation.spawn_obstacle( "./agent_control/simulation/images/arena_new.png", - self.arena_pos.x, - self.arena_pos.y, + int(self.arena_pos.x), + int(self.arena_pos.y), ) def draw_source(self) -> None: """Draw source area where good parts are located.""" self.simulation.spawn_site( "./agent_control/simulation/images/source_green.png", - self.source_pos.x, - self.source_pos.y, + int(self.source_pos.x), + 
int(self.source_pos.y), ) def draw_nest(self) -> None: """Draw base/nest area where parts should be delivered.""" self.simulation.spawn_site( "./agent_control/simulation/images/blue_nest.png", - self.base_pos.x, - self.base_pos.y, + int(self.base_pos.x), + int(self.base_pos.y), ) def draw_waste(self) -> None: """Draw waste area where bad parts should be disposed.""" self.simulation.spawn_site( "./agent_control/simulation/images/waste_red.png", - self.waste_pos.x, - self.waste_pos.y, + int(self.waste_pos.x), + int(self.waste_pos.y), ) def spawn_part(self, type: str, pos: Vector2) -> None: @@ -141,7 +141,7 @@ def remove_part(self, part: Part) -> None: Args: part: Part object to remove from simulation """ - part.kill() + part.kill() # type: ignore[no-untyped-call] def place_parts(self, num_parts: int) -> None: """ diff --git a/data_grammar/__init__.py b/data_grammar/__init__.py index dffa68a..b88242d 100644 --- a/data_grammar/__init__.py +++ b/data_grammar/__init__.py @@ -1,5 +1,6 @@ """Data Grammar package for generating behavior trees and datasets.""" from .dataset_generator import DatasetGenerator +from .rlhf_generation.rlhf_unified import UnifiedRLHFUI -__all__ = ["DatasetGenerator"] +__all__ = ["DatasetGenerator", "UnifiedRLHFUI"] diff --git a/data_grammar/dataset_generation/sys_prompt.py b/data_grammar/dataset_generation/sys_prompt.py index 99a00f0..6f5efdc 100644 --- a/data_grammar/dataset_generation/sys_prompt.py +++ b/data_grammar/dataset_generation/sys_prompt.py @@ -504,7 +504,7 @@ Metrics: -{"good_parts_picked_up": 1} +{{"good_parts_picked_up": 1}} Task: Your task is to find a good part and stop moving. @@ -528,7 +528,7 @@ Metrics: -{"good_parts_picked_up": 1} +{{"good_parts_picked_up": 1}} Task: Your task is to find a scrap part and stop moving. 
@@ -552,7 +552,7 @@ Metrics: -{"bad_parts_picked_up": 1} +{{"bad_parts_picked_up": 1}} -Examples for finding and bringing parts to areas: @@ -584,7 +584,7 @@ Metrics: -{"good_parts_picked_up": 1, "parts_dropped_in_construction": [1, 0]} +{{"good_parts_picked_up": 1, "parts_dropped_in_construction": [1, 0]}} Task: Your task is to find scrap parts and bring them to the source area. @@ -614,7 +614,7 @@ Metrics: -{"bad_parts_picked_up": 1, "parts_dropped_in_source": [0, 1]} +{{"bad_parts_picked_up": 1, "parts_dropped_in_source": [0, 1]}} - Complex example on how to handle both good and scrap parts to different areas: @@ -666,7 +666,7 @@ Metrics: -{"good_parts_picked_up": 1, "bad_parts_picked_up": 1, "parts_dropped_in_storage": [1, 0], "parts_dropped_in_source": [0, 1]} +{{"good_parts_picked_up": 1, "bad_parts_picked_up": 1, "parts_dropped_in_storage": [1, 0], "parts_dropped_in_source": [0, 1]}} You can use the examples above to help you understand how nodes come toguether to build functional units of a behaviour tree, but you can also take direct inspiration to populate a tree if the strucutre is the same or similar to an example. 
@@ -722,7 +722,7 @@ Metrics: -{"bad_parts_picked_up": 1} +{{"bad_parts_picked_up": 1}} """ diff --git a/data_grammar/dataset_generator.py b/data_grammar/dataset_generator.py index 839cb5b..476b1a1 100644 --- a/data_grammar/dataset_generator.py +++ b/data_grammar/dataset_generator.py @@ -27,7 +27,7 @@ # Add type: ignore for dotenv import since it's an optional dependency try: - from dotenv import load_dotenv # type: ignore + from dotenv import load_dotenv load_dotenv() except ImportError: diff --git a/data_grammar/rlhf_generation/__init__.py b/data_grammar/rlhf_generation/__init__.py new file mode 100644 index 0000000..ecc99e7 --- /dev/null +++ b/data_grammar/rlhf_generation/__init__.py @@ -0,0 +1 @@ +"""Module for generating datasets trough methods inspired by RLHF.""" diff --git a/data_grammar/rlhf_generation/output/dataset_path.json b/data_grammar/rlhf_generation/output/dataset_path.json new file mode 100644 index 0000000..deea0ff --- /dev/null +++ b/data_grammar/rlhf_generation/output/dataset_path.json @@ -0,0 +1,98 @@ +[ + { + "layman_task": "Find scrap parts and pick them up", + "technical_task": "if you detect a scrap part then pick up the part, or otherwise then search randomly.", + "spoon_task": "if you detect a scrap part then pick up the part, or otherwise then walk randomly.", + "tree": "\n \n \n is_scrap_part_detected\n pick_up_part\n \n state_random_walk\n \n" + }, + { + "layman_task": "Find good parts and pick them up", + "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n" + }, + { + "layman_task": "Find good parts and take them to the base", + "technical_task": "if you are holding a good part then go to the base, or if you detect a good part then pick up the part, or otherwise then go to the source.", + 
"spoon_task": "if you are holding a good part then seek the base area, or if you detect a good part then pick up the part, or otherwise then seek the source area.", + "tree": "\n \n \n is_agent_holding_good_part\n state_seek_base_area\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n" + }, + { + "layman_task": "Find scrap parts and take them to the waste", + "technical_task": "if you detect a scrap part then pick up the part, or if you are holding a scrap part then go to the waste, or if you are in the waste then drop the part, or otherwise then search randomly.", + "spoon_task": "if you detect a scrap part then pick up the part, or if you are holding a scrap part then seek the waste area, or if you are in the waste area then drop the part, or otherwise then walk randomly.", + "tree": "\n \n \n is_scrap_part_detected\n pick_up_part\n \n \n is_agent_holding_scrap_part\n state_seek_waste_area\n \n \n is_agent_in_waste_area\n drop_part\n \n state_random_walk\n \n" + }, + { + "layman_task": "Find good parts and pick them up", + "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source area.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n" + }, + { + "layman_task": "find good parts and pick them up", + "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n" + }, + { + "layman_task": "Find good parts and pick them up", + "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source 
area.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n" + }, + { + "layman_task": "Find scrap parts and pick them up", + "technical_task": "if you detect a scrap part then pick up the part, or otherwise then search randomly.", + "spoon_task": "if you detect a scrap part then pick up the part, or otherwise then walk randomly.", + "tree": "\n \n \n is_scrap_part_detected\n pick_up_part\n \n state_random_walk\n \n" + }, + { + "layman_task": "Find good parts and pick them up nigga", + "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n" + }, + { + "layman_task": "Find good parts and pick them up", + "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source area.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n" + }, + { + "layman_task": "Find good parts and pick them up", + "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n" + }, + { + "layman_task": "Find good parts and pick them up", + "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source area.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n" + }, + { + "layman_task": "Pick up scrap parts and pick them up", + 
"technical_task": "if you detect a scrap part then pick up the part, or otherwise then search randomly.", + "spoon_task": "if you detect a scrap part then pick up the part, or otherwise then walk randomly.", + "tree": "\n \n \n is_scrap_part_detected\n pick_up_part\n \n state_random_walk\n \n" + }, + { + "layman_task": "Find good parts and pick them up", + "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source area.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n" + }, + { + "layman_task": "Find good parts and pick them up", + "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.", + "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.", + "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n" + }, + { + "layman_task": "Find scrap parts and picm them up and take them to the waste area and drop them there", + "technical_task": "if you are holding a scrap part and if you are in the waste then drop the part, or if you are holding a scrap part then go to the waste, or otherwise if you detect a scrap part then pick up the part.", + "spoon_task": "if you are holding a scrap part and if you are in the waste area then drop the part, or if you are holding a scrap part then seek the waste area, or otherwise if you detect a scrap part then pick up the part.", + "tree": "\n \n \n is_agent_holding_scrap_part\n is_agent_in_waste_area\n drop_part\n \n \n is_agent_holding_scrap_part\n state_seek_waste_area\n \n \n is_scrap_part_detected\n pick_up_part\n \n \n" + } +] \ No newline at end of file diff --git a/data_grammar/rlhf_generation/rlhf_unified.py b/data_grammar/rlhf_generation/rlhf_unified.py new file mode 100644 index 0000000..d3c33ea --- /dev/null +++ 
b/data_grammar/rlhf_generation/rlhf_unified.py @@ -0,0 +1,1206 @@ +"""Unified RLHF-inspired dataset generation interface using Tkinter GUI.""" + +import os +import threading +import tkinter as tk +from tkinter import scrolledtext, ttk +from typing import Any, Dict, List, Optional, Type + +from dotenv import load_dotenv +from langchain_core.prompts import ChatPromptTemplate +from langchain_openai import ChatOpenAI +from pydantic import BaseModel, Field +from vi import Agent + +from agent_control import RobotAgent, RobotEnvironment +from data_grammar.rlhf_generation.utils.prompt_builder import PromptBuilder +from data_grammar.rlhf_generation.utils.run_robot_sim import run_robot_sim +from data_grammar.rlhf_generation.utils.save_data_point import save_datapoint +from tree_parser import AgentDocstringParser, BehaviorTreeGrammarValidator +from tree_parser.primitives_validator import validate_primitives + +load_dotenv() + +openai_api_key = os.getenv("OPENAI_API_KEY") + + +class GraphState(BaseModel): + """Represents the state of our graph.""" + + task_definition: str = Field( + description="The natural language definition of the task to be executed" + ) + task_metrics_goal: str = Field( + description="The user defined metrics the tree should achieve in the simulation" + ) + behaviour_tree: Optional[str] = Field( + default=None, description="The bahavior tree returned by the LLM" + ) + passed_validator: bool = Field( + default=False, + description="Whether the tree passed both the grammar and primitive validator checks", + ) + validator_feedback: Optional[str] = Field( + default=None, + description="If the tree doesnt pass, the feedback of both the grammar and primitive validator checks", + ) + task_metrics_result: Optional[Dict[str, Any]] = Field( + default=None, + description="The metrics returned by executing the tree on the simulator", + ) + human_feedback: Optional[str] = Field( + default=None, + description="The feedback of the human on the task execution that will 
be given to the LLM in case of retries", + ) + dataset_size: int = Field(description="The number of samples generated") + dataset_path: str = Field(description="The path to the dataset") + dataset_size_goal: int = Field(description="The goal number of samples to generate") + + +# ------------------------------ Unified UI Class ------------------------------ +class UnifiedRLHFUI: + """Unified RLHF dataset generation interface with persistent GUI.""" + + def __init__( + self, + dataset_path: str, + dataset_size_goal: int, + agent_class: Optional[Type[Agent]] = None, + grammar_rules: Optional[Dict[str, Any]] = None, + environment_class: Optional[Type[Any]] = None, + environment_kwargs: Optional[Dict[str, Any]] = None, + config_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize the unified RLHF UI with the given parameters. + + Args: + dataset_path: Path to the dataset file. + dataset_size_goal: Target number of datapoints to generate. + agent_class: The agent class to use for behavior tree validation. + grammar_rules: Custom grammar rules for behavior tree validation. + environment_class: Custom environment class for simulation. + environment_kwargs: Arguments to pass to the environment constructor. + config_kwargs: Configuration arguments for the agent. 
+ """ + self.dataset_path = dataset_path + self.dataset_size_goal = dataset_size_goal + self.agent_class = agent_class or RobotAgent + self.grammar_rules = grammar_rules or self._get_default_grammar_rules() + self.environment_class = environment_class or RobotEnvironment + default_env_kwargs = {"headless": True} + if environment_kwargs: + default_env_kwargs.update(environment_kwargs) + self.environment_kwargs = default_env_kwargs + self.config_kwargs = config_kwargs or {} + + self.current_state: Optional[GraphState] = None + self.workflow_running = False + self._feedback_mode: bool = False + + # Create main window + self.root = tk.Tk() + agent_name = self.agent_class.__name__ if self.agent_class else "Unknown Agent" + self.root.title(f"RLHF Dataset Generation - {agent_name}") + self.root.geometry("1000x700") + self.root.resizable(True, True) + + # Initialize workflow components + self.setup_workflow() + self.create_ui() + self.reset_workflow() + + def _get_default_grammar_rules(self) -> Dict[str, Any]: + """Get default grammar rules for behavior tree validation. + + Returns: + Dictionary containing default grammar rules for behavior tree validation. 
+ """ + return { + "B": [["b", ["SEL"]], ["b", ["SEQ"]]], + "SEL": [["sel", ["SEQn", "As"]], ["sel", ["SEQn"]]], + "SEQn": [["SEQ", "SEQn"], ["SEQ"]], + "SEQ": [["seq", ["Pn", "A"]], ["seq", ["As", "Pn", "A"]]], + "b": ["BehaviorTree", ["children_nodes"]], + "sel": ["Selector", ["children_nodes"]], + "seq": ["Sequence", ["children_nodes"]], + "A": [["aa", "sa"], ["aa"], ["sa"]], + "As": [["aa"], ["sa"]], + "aa": ["ActuatorAction"], + "sa": ["StateAction"], + "Pn": [["p", "Pn"], ["p"], []], + "p": ["Condition"], + } + + def setup_workflow(self) -> None: + """Initialize the workflow graph and LLM components.""" + self.prompt_builder = PromptBuilder(self.agent_class) + self.system_prompt = self.prompt_builder.build_system_prompt() + + self.tree_generator_prompt = ChatPromptTemplate.from_messages( + [("system", self.system_prompt), ("placeholder", "{user_prompt}")] + ) + + self.tree_generator_llm = ChatOpenAI(model="gpt-4o", temperature=0) + + # Initialize agent doc parser for extracting node information + if self.agent_class: + self.agent_doc_parser: Optional[AgentDocstringParser] = ( + AgentDocstringParser(self.agent_class) + ) + self.agent_config: Optional[Dict[str, Any]] = ( + self.agent_doc_parser.extract_docstring_config() + ) + else: + self.agent_doc_parser = None + self.agent_config = None + + def create_ui(self) -> None: + """Create the main UI with tabs.""" + # Create notebook for tabs + self.notebook = ttk.Notebook(self.root) + self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Create tabs + self.create_main_tab() + self.create_nodes_tab() + self.create_dataset_tab() + + def create_main_tab(self) -> None: + """Create the main workflow tab.""" + self.main_frame = ttk.Frame(self.notebook) + self.notebook.add(self.main_frame, text="Main Workflow") + + # Create main layout + main_container = ttk.Frame(self.main_frame) + main_container.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Left side - Controls and inputs + left_frame = 
ttk.Frame(main_container) + left_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=False, padx=(0, 10)) + left_frame.config(width=400) + + # Right side - Information display + right_frame = ttk.Frame(main_container) + right_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True) + + self.create_left_panel(left_frame) + self.create_right_panel(right_frame) + + def create_left_panel(self, parent: tk.Widget) -> tk.Widget: + """Create the left control panel. + + Args: + parent: The parent widget to attach the panel to. + + Returns: + The created left panel widget. + """ + # Dataset info + info_frame = ttk.LabelFrame(parent, text="Dataset Information", padding=10) + info_frame.pack(fill=tk.X, pady=(0, 10)) + + self.dataset_size_label = ttk.Label( + info_frame, text="Current Size: 0", font=("Arial", 10) + ) + self.dataset_size_label.pack(anchor=tk.W) + + self.dataset_goal_label = ttk.Label( + info_frame, + text=f"Target Size: {self.dataset_size_goal}", + font=("Arial", 10), + ) + self.dataset_goal_label.pack(anchor=tk.W) + + self.progress_label = ttk.Label( + info_frame, text="Progress: 0.0%", font=("Arial", 10, "bold") + ) + self.progress_label.pack(anchor=tk.W) + + # Task input + task_frame = ttk.LabelFrame(parent, text="Task Definition", padding=10) + task_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10)) + + ttk.Label( + task_frame, text="Task Description:", font=("Arial", 10, "bold") + ).pack(anchor=tk.W) + self.task_entry = tk.Text(task_frame, height=6, wrap=tk.WORD) + self.task_entry.pack(fill=tk.BOTH, expand=True, pady=(5, 10)) + + ttk.Label(task_frame, text="Metrics Goal:", font=("Arial", 10, "bold")).pack( + anchor=tk.W + ) + self.metrics_entry = tk.Text(task_frame, height=4, wrap=tk.WORD) + self.metrics_entry.pack(fill=tk.BOTH, expand=True, pady=(5, 10)) + + # Feedback input (initially hidden) + self.feedback_frame = ttk.LabelFrame(parent, text="Human Feedback", padding=10) + ttk.Label( + self.feedback_frame, + text="Feedback for improvement:", + font=("Arial", 10, 
"bold"), + ).pack(anchor=tk.W) + self.feedback_entry = tk.Text(self.feedback_frame, height=4, wrap=tk.WORD) + self.feedback_entry.pack(fill=tk.BOTH, expand=True, pady=(5, 0)) + + # Buttons frame + button_frame = ttk.Frame(parent) + button_frame.pack(fill=tk.X, pady=10) + + self.prompt_button = ttk.Button( + button_frame, text="Generate Tree", command=self.start_generation + ) + self.prompt_button.pack(side=tk.LEFT, padx=(0, 5)) + + self.run_sim_button = ttk.Button( + button_frame, + text="Run Simulation", + command=self.run_simulation, + state=tk.DISABLED, + ) + self.run_sim_button.pack(side=tk.LEFT, padx=5) + + self.feedback_button = ttk.Button( + button_frame, + text="Give Feedback", + command=self.give_feedback, + state=tk.DISABLED, + ) + self.feedback_button.pack(side=tk.LEFT, padx=5) + + self.save_button = ttk.Button( + button_frame, + text="Save Datapoint", + command=self.save_datapoint, + state=tk.DISABLED, + ) + self.save_button.pack(side=tk.LEFT, padx=5) + + return parent + + def create_right_panel(self, parent: tk.Widget) -> tk.Widget: + """Create the right information display panel. + + Args: + parent: The parent widget to attach the panel to. + + Returns: + The created right panel widget. 
+ """ + # Status display + status_frame = ttk.LabelFrame(parent, text="Status", padding=10) + status_frame.pack(fill=tk.X, pady=(0, 10)) + + self.status_label = ttk.Label( + status_frame, + text="Ready to start", + font=("Arial", 12, "bold"), + foreground="blue", + ) + self.status_label.pack() + + # Generated tree display + tree_frame = ttk.LabelFrame(parent, text="Generated Behavior Tree", padding=10) + tree_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10)) + + self.tree_display = scrolledtext.ScrolledText( + tree_frame, height=12, wrap=tk.WORD, state=tk.DISABLED + ) + self.tree_display.pack(fill=tk.BOTH, expand=True) + + # Metrics/Feedback display + metrics_frame = ttk.LabelFrame(parent, text="Validation & Metrics", padding=10) + metrics_frame.pack(fill=tk.BOTH, expand=True) + + self.metrics_display = scrolledtext.ScrolledText( + metrics_frame, height=8, wrap=tk.WORD, state=tk.DISABLED + ) + self.metrics_display.pack(fill=tk.BOTH, expand=True) + + return parent + + def create_nodes_tab(self) -> None: + """Create the nodes information tab.""" + self.nodes_frame = ttk.Frame(self.notebook) + self.notebook.add(self.nodes_frame, text="Available Nodes") + + # Main container + main_container = ttk.Frame(self.nodes_frame) + main_container.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Title + title_frame = ttk.Frame(main_container) + title_frame.pack(fill=tk.X, pady=(0, 10)) + + ttk.Label( + title_frame, text="Available Agent Nodes", font=("Arial", 16, "bold") + ).pack(side=tk.LEFT) + + # Create notebook for different node types + self.nodes_notebook = ttk.Notebook(main_container) + self.nodes_notebook.pack(fill=tk.BOTH, expand=True) + + # Create tabs for each node type + self.create_conditions_tab() + self.create_actuator_actions_tab() + self.create_state_actions_tab() + + def create_conditions_tab(self) -> None: + """Create tab for condition nodes.""" + conditions_frame = ttk.Frame(self.nodes_notebook) + self.nodes_notebook.add(conditions_frame, 
text="Conditions") + + # Scrollable text area + text_frame = ttk.Frame(conditions_frame) + text_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + self.conditions_display = scrolledtext.ScrolledText( + text_frame, wrap=tk.WORD, state=tk.DISABLED + ) + self.conditions_display.pack(fill=tk.BOTH, expand=True) + + # Populate with condition nodes + self.populate_conditions() + + def create_actuator_actions_tab(self) -> None: + """Create tab for actuator action nodes.""" + actuator_frame = ttk.Frame(self.nodes_notebook) + self.nodes_notebook.add(actuator_frame, text="Actuator Actions") + + # Scrollable text area + text_frame = ttk.Frame(actuator_frame) + text_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + self.actuator_display = scrolledtext.ScrolledText( + text_frame, wrap=tk.WORD, state=tk.DISABLED + ) + self.actuator_display.pack(fill=tk.BOTH, expand=True) + + # Populate with actuator action nodes + self.populate_actuator_actions() + + def create_state_actions_tab(self) -> None: + """Create tab for state action nodes.""" + state_frame = ttk.Frame(self.nodes_notebook) + self.nodes_notebook.add(state_frame, text="State Actions") + + # Scrollable text area + text_frame = ttk.Frame(state_frame) + text_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + self.state_display = scrolledtext.ScrolledText( + text_frame, wrap=tk.WORD, state=tk.DISABLED + ) + self.state_display.pack(fill=tk.BOTH, expand=True) + + # Populate with state action nodes + self.populate_state_actions() + + def populate_conditions(self) -> None: + """Populate conditions tab with available condition nodes.""" + self.conditions_display.config(state=tk.NORMAL) + self.conditions_display.delete(1.0, tk.END) + + # Configure text tags for formatting + self.conditions_display.tag_configure("title", font=("Arial", 14, "bold")) + self.conditions_display.tag_configure("node_name", font=("Arial", 13, "bold")) + self.conditions_display.tag_configure("description", font=("Arial", 10)) + + 
if self.agent_config is None: + self.conditions_display.insert( + tk.END, "No agent configuration available.", "title" + ) + self.conditions_display.config(state=tk.DISABLED) + return + + conditions = self.agent_config.get("conditions", []) + + if conditions: + # Insert title + self.conditions_display.insert( + tk.END, "Available Condition Nodes:\n", "title" + ) + self.conditions_display.insert(tk.END, "=" * 50 + "\n\n") + + for i, condition in enumerate(conditions, 1): + self.conditions_display.insert( + tk.END, f"{i}. {condition}\n", "node_name" + ) + self.conditions_display.insert(tk.END, "-" * 30 + "\n") + + # Try to get more info about the condition from the agent + try: + if hasattr(self.agent_class, condition): + method = getattr(self.agent_class, condition) + if hasattr(method, "__doc__") and method.__doc__: + self.conditions_display.insert( + tk.END, + f"Description: {method.__doc__.strip()}\n", + "description", + ) + else: + self.conditions_display.insert( + tk.END, + "Description: No documentation available\n", + "description", + ) + else: + self.conditions_display.insert( + tk.END, + "Description: Method not found in agent class\n", + "description", + ) + except Exception as e: + self.conditions_display.insert( + tk.END, + f"Description: Error retrieving info - {str(e)}\n", + "description", + ) + + self.conditions_display.insert(tk.END, "\n") + else: + self.conditions_display.insert(tk.END, "No condition nodes found.", "title") + + self.conditions_display.config(state=tk.DISABLED) + + def populate_actuator_actions(self) -> None: + """Populate actuator actions tab with available actuator action nodes.""" + self.actuator_display.config(state=tk.NORMAL) + self.actuator_display.delete(1.0, tk.END) + + # Configure text tags for formatting + self.actuator_display.tag_configure("title", font=("Arial", 14, "bold")) + self.actuator_display.tag_configure("node_name", font=("Arial", 13, "bold")) + self.actuator_display.tag_configure("description", font=("Arial", 
10)) + + if self.agent_config is None: + self.actuator_display.insert( + tk.END, "No agent configuration available.", "title" + ) + self.actuator_display.config(state=tk.DISABLED) + return + + actuator_actions = self.agent_config.get("actuator_actions", []) + + if actuator_actions: + self.actuator_display.insert( + tk.END, "Available Actuator Action Nodes:\n", "title" + ) + self.actuator_display.insert(tk.END, "=" * 50 + "\n\n") + + for i, action in enumerate(actuator_actions, 1): + self.actuator_display.insert(tk.END, f"{i}. {action}\n", "node_name") + self.actuator_display.insert(tk.END, "-" * 30 + "\n") + + # Try to get more info about the action from the agent + try: + if hasattr(self.agent_class, action): + method = getattr(self.agent_class, action) + if hasattr(method, "__doc__") and method.__doc__: + self.actuator_display.insert( + tk.END, + f"Description: {method.__doc__.strip()}\n", + "description", + ) + else: + self.actuator_display.insert( + tk.END, + "Description: No documentation available\n", + "description", + ) + else: + self.actuator_display.insert( + tk.END, + "Description: Method not found in agent class\n", + "description", + ) + except Exception as e: + self.actuator_display.insert( + tk.END, + f"Description: Error retrieving info - {str(e)}\n", + "description", + ) + + self.actuator_display.insert(tk.END, "\n") + else: + self.actuator_display.insert( + tk.END, "No actuator action nodes found.", "title" + ) + + self.actuator_display.config(state=tk.DISABLED) + + def populate_state_actions(self) -> None: + """Populate state actions tab with available state action nodes.""" + self.state_display.config(state=tk.NORMAL) + self.state_display.delete(1.0, tk.END) + + # Configure text tags for formatting + self.state_display.tag_configure("title", font=("Arial", 14, "bold")) + self.state_display.tag_configure("node_name", font=("Arial", 13, "bold")) + self.state_display.tag_configure("description", font=("Arial", 10)) + + if self.agent_config is None: 
+ self.state_display.insert( + tk.END, "No agent configuration available.", "title" + ) + self.state_display.config(state=tk.DISABLED) + return + + state_actions = self.agent_config.get("state_actions", []) + + if state_actions: + self.state_display.insert( + tk.END, "Available State Action Nodes:\n", "title" + ) + self.state_display.insert(tk.END, "=" * 50 + "\n\n") + + for i, action in enumerate(state_actions, 1): + self.state_display.insert(tk.END, f"{i}. {action}\n", "node_name") + self.state_display.insert(tk.END, "-" * 30 + "\n") + + # Try to get more info about the action from the agent + try: + if hasattr(self.agent_class, action): + method = getattr(self.agent_class, action) + if hasattr(method, "__doc__") and method.__doc__: + self.state_display.insert( + tk.END, + f"Description: {method.__doc__.strip()}\n", + "description", + ) + else: + self.state_display.insert( + tk.END, + "Description: No documentation available\n", + "description", + ) + else: + self.state_display.insert( + tk.END, + "Description: Method not found in agent class\n", + "description", + ) + except Exception as e: + self.state_display.insert( + tk.END, + f"Description: Error retrieving info - {str(e)}\n", + "description", + ) + + self.state_display.insert(tk.END, "\n") + else: + self.state_display.insert(tk.END, "No state action nodes found.", "title") + + self.state_display.config(state=tk.DISABLED) + + def create_dataset_tab(self) -> None: + """Create the dataset exploration tab.""" + self.dataset_frame = ttk.Frame(self.notebook) + self.notebook.add(self.dataset_frame, text="Dataset Explorer") + + # Main container + main_container = ttk.Frame(self.dataset_frame) + main_container.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # Title and refresh button + title_frame = ttk.Frame(main_container) + title_frame.pack(fill=tk.X, pady=(0, 10)) + + ttk.Label( + title_frame, text="Dataset Explorer", font=("Arial", 16, "bold") + ).pack(side=tk.LEFT) + self.refresh_button = ttk.Button( + 
title_frame, text="Refresh", command=self.refresh_dataset_list + ) + self.refresh_button.pack(side=tk.RIGHT) + + # Left side - Datapoint list + left_frame = ttk.LabelFrame(main_container, text="Datapoints", padding=10) + left_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=False, padx=(0, 10)) + left_frame.config(width=300) + + # Listbox with scrollbar for datapoints + list_frame = ttk.Frame(left_frame) + list_frame.pack(fill=tk.BOTH, expand=True) + + self.datapoint_listbox = tk.Listbox(list_frame, width=40) + self.datapoint_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + self.datapoint_listbox.bind("<<ListboxSelect>>", self.on_datapoint_select) + + list_scrollbar = ttk.Scrollbar( + list_frame, orient=tk.VERTICAL, command=self.datapoint_listbox.yview + ) + list_scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + self.datapoint_listbox.config(yscrollcommand=list_scrollbar.set) + + # Right side - Datapoint details + right_frame = ttk.LabelFrame( + main_container, text="Datapoint Details", padding=10 + ) + right_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True) + + # Layman task display + ttk.Label(right_frame, text="Layman Task:", font=("Arial", 12, "bold")).pack( + anchor=tk.W, pady=(0, 5) + ) + self.layman_display = scrolledtext.ScrolledText( + right_frame, height=6, wrap=tk.WORD, state=tk.DISABLED + ) + self.layman_display.pack(fill=tk.X, pady=(0, 15)) + + # Behavior tree display + ttk.Label(right_frame, text="Behavior Tree:", font=("Arial", 12, "bold")).pack( + anchor=tk.W, pady=(0, 5) + ) + self.tree_details_display = scrolledtext.ScrolledText( + right_frame, height=15, wrap=tk.WORD, state=tk.DISABLED + ) + self.tree_details_display.pack(fill=tk.BOTH, expand=True) + + # Load initial data + self.dataset_data: List[Dict[str, Any]] = [] + self.refresh_dataset_list() + + def reset_workflow(self) -> None: + """Reset the workflow to initial state.""" + current_size = self.get_current_dataset_size() + + self.current_state = GraphState( + task_definition="", + task_metrics_goal="", + 
dataset_size=current_size, + dataset_path=self.dataset_path, + dataset_size_goal=self.dataset_size_goal, + ) + self.workflow_running = False + + # Clear input fields for fresh start + self.task_entry.delete(1.0, tk.END) + self.metrics_entry.delete(1.0, tk.END) + if hasattr(self, "feedback_entry"): + self.feedback_entry.delete(1.0, tk.END) + + self.update_dataset_info() + self.update_status("Ready to start", "blue") + self.clear_displays() + self.reset_buttons() + + def get_current_dataset_size(self) -> int: + """Get current dataset size from file. + + Returns: + The current number of datapoints in the dataset file. + """ + import json + + if os.path.exists(self.dataset_path): + try: + with open(self.dataset_path, "r") as f: + dataset = json.load(f) + return len(dataset) + except (json.JSONDecodeError, FileNotFoundError): + return 0 + return 0 + + def update_dataset_info(self) -> None: + """Update dataset information display.""" + current_size = self.get_current_dataset_size() + progress = ( + (current_size / self.dataset_size_goal) * 100 + if self.dataset_size_goal > 0 + else 0 + ) + + self.dataset_size_label.config(text=f"Current Size: {current_size}") + self.progress_label.config(text=f"Progress: {progress:.1f}%") + + def update_status(self, message: str, color: str = "black") -> None: + """Update status message. + + Args: + message: The status message to display. + color: The color of the status message. 
+ """ + self.status_label.config(text=message, foreground=color) + self.root.update() + + def clear_displays(self) -> None: + """Clear all display areas.""" + self.tree_display.config(state=tk.NORMAL) + self.tree_display.delete(1.0, tk.END) + self.tree_display.config(state=tk.DISABLED) + + self.metrics_display.config(state=tk.NORMAL) + self.metrics_display.delete(1.0, tk.END) + self.metrics_display.config(state=tk.DISABLED) + + def reset_buttons(self) -> None: + """Reset button states.""" + self.prompt_button.config(state=tk.NORMAL) + self.run_sim_button.config(state=tk.DISABLED) + self.feedback_button.config(state=tk.NORMAL) + self.save_button.config(state=tk.NORMAL) + self.feedback_frame.pack_forget() + + def start_generation(self) -> None: + """Start the tree generation process.""" + if self.workflow_running: + return + + # Get input values + task_definition = self.task_entry.get(1.0, tk.END).strip() + task_metrics_goal = self.metrics_entry.get(1.0, tk.END).strip() + + if not task_definition: + self.update_status("Please fill in the task definition", "red") + return + + # Update state + if self.current_state: + self.current_state.task_definition = task_definition + self.current_state.task_metrics_goal = task_metrics_goal + + # Get feedback if available + if ( + hasattr(self, "_feedback_mode") + and self._feedback_mode + and self.current_state + ): + human_feedback = self.feedback_entry.get(1.0, tk.END).strip() + self.current_state.human_feedback = ( + human_feedback if human_feedback else None + ) + self._feedback_mode = False + self.feedback_frame.pack_forget() + else: + if self.current_state: + self.current_state.human_feedback = None + + # Start generation in separate thread + self.workflow_running = True + self.prompt_button.config(state=tk.DISABLED) + self.update_status("Generating behavior tree...", "orange") + + threading.Thread(target=self.generate_tree, daemon=True).start() + + def generate_tree(self) -> None: + """Generate behavior tree using LLM.""" + 
try: + if not self.current_state: + self.root.after( + 0, self.on_generation_error, "No current state available" + ) + return + + # Prepare base prompt + prompt = f"Please generate a behaviour tree for the following task: {self.current_state.task_definition}" + + # Add metrics goal if provided + if self.current_state.task_metrics_goal: + prompt += f"\nThe task metrics to achieve are: {self.current_state.task_metrics_goal}" + + # Add feedback section if there's human feedback + if self.current_state.human_feedback: + previous_tree = self.current_state.behaviour_tree or "No previous tree" + prompt += f"\n\nYour previous attempt was:\n{previous_tree}\n\nHuman Feedback on Previous Attempt: {self.current_state.human_feedback}\n\nPlease incorporate this feedback to improve the behavior tree." + + # Generate tree + class TreeGeneratorOutput(BaseModel): + behaviour_tree: str = Field( + description="The raw behaviour tree in XML format without any quotes or markdown formatting" + ) + + tree_gen_chain = ( + self.tree_generator_prompt + | self.tree_generator_llm.with_structured_output(TreeGeneratorOutput) + ) + result = tree_gen_chain.invoke({"user_prompt": [("user", prompt)]}) + + # Update UI in main thread + if hasattr(result, "behaviour_tree"): + self.root.after(0, self.on_tree_generated, result.behaviour_tree) + elif isinstance(result, dict) and "behaviour_tree" in result: + self.root.after(0, self.on_tree_generated, result["behaviour_tree"]) + else: + self.root.after( + 0, + self.on_generation_error, + f"Unexpected result format: {type(result)}", + ) + + except Exception as e: + self.root.after(0, self.on_generation_error, str(e)) + + def on_tree_generated(self, behaviour_tree: str) -> None: + """Handle successful tree generation. + + Args: + behaviour_tree: The generated behavior tree string. 
+ """ + if self.current_state: + self.current_state.behaviour_tree = behaviour_tree + + # Display tree + self.tree_display.config(state=tk.NORMAL) + self.tree_display.delete(1.0, tk.END) + self.tree_display.insert(1.0, behaviour_tree) + self.tree_display.config(state=tk.DISABLED) + + self.update_status("Tree generated! Validating...", "orange") + + # Validate tree + threading.Thread(target=self.validate_tree, daemon=True).start() + + def on_generation_error(self, error_msg: str) -> None: + """Handle tree generation errors. + + Args: + error_msg: The error message to display. + """ + self.update_status(f"Generation failed: {error_msg}", "red") + + # Reset workflow state + self.workflow_running = False + self.prompt_button.config(state=tk.NORMAL) + + def validate_tree(self) -> None: + """Validate the generated behavior tree.""" + if not self.current_state or not self.current_state.behaviour_tree: + self.update_status("No tree to validate", "red") + return + + self.update_status("Validating tree...", "orange") + + threading.Thread( + target=lambda: self._validate_tree_thread(), daemon=True + ).start() + + def _validate_tree_thread(self) -> None: + """Run tree validation in background thread.""" + try: + if not self.current_state or not self.current_state.behaviour_tree: + self.root.after( + 0, self.on_validation_error, "No tree available for validation" + ) + return + + grammar_validator = BehaviorTreeGrammarValidator(self.grammar_rules) + passed_grammar, grammar_feedback = grammar_validator.validate_tree( + self.current_state.behaviour_tree + ) + passed_primitive, primitive_feedback = validate_primitives( + self.current_state.behaviour_tree, self.agent_class + ) + + passed = passed_grammar and passed_primitive + feedback = "" + if not passed_grammar: + feedback += f"Grammar validation failed: {grammar_feedback}\n" + if not passed_primitive: + feedback += f"Primitive validation failed: {primitive_feedback}" + + self.root.after(0, self.on_validation_complete, passed, 
feedback) + except Exception as e: + self.root.after(0, self.on_validation_error, str(e)) + + def on_validation_complete(self, passed: bool, feedback: str) -> None: + """Handle validation completion. + + Args: + passed: Whether the validation passed. + feedback: The validation feedback message. + """ + if self.current_state: + self.current_state.passed_validator = passed + self.current_state.validator_feedback = feedback + + if passed: + self.update_status("Tree validation passed!", "green") + if self.current_state: + self.metrics_display.config(state=tk.NORMAL) + self.metrics_display.delete(1.0, tk.END) + self.metrics_display.insert( + 1.0, f"Desired Metrics:\n{self.current_state.task_metrics_goal}" + ) + self.metrics_display.config(state=tk.DISABLED) + else: + self.update_status("Tree validation failed!", "red") + self.metrics_display.config(state=tk.NORMAL) + self.metrics_display.delete(1.0, tk.END) + self.metrics_display.insert( + 1.0, + f"Validation Feedback:\n{feedback}\n\nSimulation cannot run with invalid tree.", + ) + self.metrics_display.config(state=tk.DISABLED) + + # Enable buttons based on validation result + self.run_sim_button.config(state=tk.NORMAL if passed else tk.DISABLED) + self.feedback_button.config(state=tk.NORMAL) + self.save_button.config(state=tk.NORMAL) + + # Reset workflow state + self.workflow_running = False + self.prompt_button.config(state=tk.NORMAL) + + def on_validation_error(self, error_msg: str) -> None: + """Handle validation errors. + + Args: + error_msg: The error message to display. 
+ """ + self.update_status(f"Validation failed: {error_msg}", "red") + + # Reset workflow state + self.workflow_running = False + self.prompt_button.config(state=tk.NORMAL) + + def run_simulation(self) -> None: + """Run simulation of the behavior tree.""" + if not self.current_state or not self.current_state.passed_validator: + self.update_status("Cannot run simulation: tree validation failed", "red") + return + + self.update_status("Running simulation...", "orange") + self.run_sim_button.config(state=tk.DISABLED) + + threading.Thread(target=self.run_sim_thread, daemon=True).start() + + def run_sim_thread(self) -> None: + """Run simulation in background thread. + + Raises: + ValueError: If there is an error during simulation execution. + """ + try: + if not self.current_state or not self.current_state.behaviour_tree: + raise ValueError("No valid tree available for simulation") + + metrics_result = run_robot_sim( + self.current_state.behaviour_tree, + environment_class=self.environment_class, + environment_kwargs=self.environment_kwargs, + config_kwargs=self.config_kwargs, + ) + + self.root.after(0, self.on_simulation_complete, metrics_result) + except Exception as e: + self.root.after(0, self.on_simulation_error, str(e)) + + def on_simulation_complete(self, metrics_result: Dict[str, Any]) -> None: + """Handle simulation completion. + + Args: + metrics_result: The simulation metrics and results. 
+ """ + if self.current_state: + self.current_state.task_metrics_result = metrics_result + + # Update metrics display + self.metrics_display.config(state=tk.NORMAL) + self.metrics_display.delete(1.0, tk.END) + + if self.current_state: + content = f"Desired Metrics:\n{self.current_state.task_metrics_goal}\n\n" + content += f"Achieved Metrics:\n{metrics_result}" + self.metrics_display.insert(1.0, content) + + self.metrics_display.config(state=tk.DISABLED) + self.update_status("Simulation completed!", "green") + + # Re-enable the simulation button + self.run_sim_button.config(state=tk.NORMAL) + + # Force UI update to ensure responsiveness + self.root.update_idletasks() + self.root.update() + + def on_simulation_error(self, error_msg: str) -> None: + """Handle simulation errors. + + Args: + error_msg: The error message to display. + """ + self.update_status(f"Simulation failed: {error_msg}", "red") + + # Re-enable the simulation button on error + self.run_sim_button.config(state=tk.NORMAL) + + # Force UI update to ensure responsiveness + self.root.update_idletasks() + self.root.update() + + def give_feedback(self) -> None: + """Enable feedback mode for human input.""" + if not self.current_state or not self.current_state.behaviour_tree: + self.update_status("No tree available for feedback", "red") + return + + self._feedback_mode = True + + # Show and setup the feedback frame + self.feedback_frame.pack(fill=tk.X, pady=(0, 10)) + self.feedback_entry.delete(1.0, tk.END) + self.feedback_entry.focus() + + self.update_status( + "Provide feedback and click 'Generate Tree' to retry", "blue" + ) + + def save_datapoint(self) -> None: + """Save current tree and metrics as a datapoint.""" + if not self.current_state or not self.current_state.behaviour_tree: + self.update_status("No tree to save", "red") + return + + try: + save_datapoint( + dataset_path=self.current_state.dataset_path, + task_description=self.current_state.task_definition, + 
tree_str=self.current_state.behaviour_tree, + agent_class=self.agent_class, + ) + + self.update_dataset_info() + self.update_status("Datapoint saved successfully!", "green") + + # Check if goal is reached + if self.ask_continue_generation(): + self.reset_workflow() + else: + self.update_status("Dataset generation completed!", "blue") + except Exception as e: + self.update_status(f"Failed to save datapoint: {e}", "red") + + def ask_continue_generation(self) -> bool: + """Ask user if they want to continue generating data. + + Returns: + True if the user wants to continue, False otherwise. + """ + if not self.current_state: + return False + + current_size = self.get_current_dataset_size() + + if current_size >= self.dataset_size_goal: + import tkinter.messagebox as msgbox + + result = msgbox.askyesno( + "Goal Reached", + f"Dataset size goal of {self.dataset_size_goal} reached! " + f"Current size: {current_size}. Continue generating?", + ) + else: + import tkinter.messagebox as msgbox + + result = msgbox.askyesno( + "Continue Generation", + f"Datapoint saved! ({current_size}/{self.dataset_size_goal})\nWould you like to continue generating more datapoints?", + ) + + if result: + self.reset_workflow() + else: + self.update_status("Generation complete", "blue") + + # Refresh dataset explorer + self.refresh_dataset_list() + + return result + + def refresh_dataset_list(self) -> None: + """Refresh the dataset list in the dataset tab.""" + if not hasattr(self, "datapoint_listbox"): + return + + self.datapoint_listbox.delete(0, tk.END) + + import json + + try: + if os.path.exists(self.dataset_path): + with open(self.dataset_path, "r") as f: + dataset = json.load(f) + for i, datapoint in enumerate(dataset): + preview = datapoint.get("layman_task", "No task description")[ + :50 + ] + if len(preview) == 50: + preview += "..." + self.datapoint_listbox.insert(tk.END, f"{i+1}. 
{preview}") + except (json.JSONDecodeError, FileNotFoundError): + pass + + self.clear_dataset_details() + + def on_datapoint_select(self, event: Any) -> None: + """Handle datapoint selection in the list. + + Args: + event: The selection event from the listbox. + """ + try: + selection = self.datapoint_listbox.curselection() # type: ignore[no-untyped-call] + if not selection: + return + + index = selection[0] + import json + + with open(self.dataset_path, "r") as f: + dataset = json.load(f) + if index < len(dataset): + self.display_datapoint_details(dataset[index]) + except Exception: + pass + + def display_datapoint_details(self, datapoint: Dict[str, Any]) -> None: + """Display details of selected datapoint. + + Args: + datapoint: The datapoint dictionary containing task and tree information. + """ + # Display layman task + self.layman_display.config(state=tk.NORMAL) + self.layman_display.delete(1.0, tk.END) + self.layman_display.insert( + 1.0, datapoint.get("layman_task", "No task description") + ) + self.layman_display.config(state=tk.DISABLED) + + # Display formatted tree + self.tree_details_display.config(state=tk.NORMAL) + self.tree_details_display.delete(1.0, tk.END) + + tree_content = datapoint.get("tree", "No tree available") + if tree_content: + # Format XML for better readability + try: + import xml.dom.minidom + + dom = xml.dom.minidom.parseString(tree_content) + formatted_tree = dom.toprettyxml(indent=" ") + # Remove empty lines + formatted_tree = "\n".join( + [line for line in formatted_tree.split("\n") if line.strip()] + ) + self.tree_details_display.insert(1.0, formatted_tree) + except Exception: + self.tree_details_display.insert(1.0, tree_content) + + self.tree_details_display.config(state=tk.DISABLED) + + def clear_dataset_details(self) -> None: + """Clear the dataset details display.""" + if hasattr(self, "layman_display"): + self.layman_display.config(state=tk.NORMAL) + self.layman_display.delete(1.0, tk.END) + 
self.layman_display.config(state=tk.DISABLED) + + if hasattr(self, "tree_details_display"): + self.tree_details_display.config(state=tk.NORMAL) + self.tree_details_display.delete(1.0, tk.END) + self.tree_details_display.config(state=tk.DISABLED) + + def run(self) -> None: + """Start the application.""" + self.root.mainloop() diff --git a/data_grammar/rlhf_generation/utils/prompt_builder.py b/data_grammar/rlhf_generation/utils/prompt_builder.py new file mode 100644 index 0000000..0ed4ef4 --- /dev/null +++ b/data_grammar/rlhf_generation/utils/prompt_builder.py @@ -0,0 +1,104 @@ +"""Class to build system prompt for the RLHF agent.""" + +import os +from typing import Dict, Type + +from vi import Agent + + +class PromptBuilder: + """Builds system prompts for RLHF agents using XML templates and agent class behaviors.""" + + def __init__(self, agent_class: Type[Agent]) -> None: + """ + Initialize the PromptBuilder with an agent class. + + Args: + agent_class: The agent class to extract behaviors from + """ + self.agent_class = agent_class + + def load_xml_template(self, file_name: str) -> str: + """ + Load an XML template from the prompt_techniques directory. + + Args: + file_name: The name of the template file + + Returns: + The contents of the template file + """ + directory = "llm_interface/prompt_techniques" + file_path = os.path.join(directory, file_name) + with open(file_path, "r") as file: + return file.read() + + def extract_agent_behaviors(self) -> Dict[str, str]: + """ + Get all the behaviors from the agent class, extracting only Node Type and Description. 
+ + Returns: + Dictionary of function names and their processed docstrings (Node Type and Description only) + """ + class_obj = self.agent_class + function_names_and_docstrings = {} + + for func_name in dir(class_obj): + if ( + callable(getattr(class_obj, func_name)) + and not func_name.startswith("__") + and not func_name.startswith("update") + and not func_name.startswith("helper") + and getattr(class_obj, func_name).__qualname__.startswith( + class_obj.__name__ + "." + ) + ): + func = getattr(class_obj, func_name) + if func.__doc__: + # Split docstring into lines and process + doc_lines = func.__doc__.strip().split("\n") + processed_doc = [] + + for line in doc_lines: + line = line.strip() + if line.startswith("Node Type:") or line.startswith( + "Description:" + ): + processed_doc.append(line) + + # Join the processed lines back together + function_names_and_docstrings[func_name] = "\n".join(processed_doc) + else: + function_names_and_docstrings[func_name] = "No docstring found." + + return function_names_and_docstrings + + def build_system_prompt(self) -> str: + """ + Build the complete system prompt by loading the XML template and replacing placeholders. 
+ + Returns: + The complete system prompt with all placeholders replaced + """ + # Load and prepare system prompt + system_prompt = self.load_xml_template("system_prompt.xml") + bt_3_example = self.load_xml_template("bt_3.xml") + behaviors = self.extract_agent_behaviors() + + # Format behaviors dictionary to avoid template variable conflicts + # Convert to a safe string representation and escape all curly braces + behaviors_str = "{{\n" # Start with escaped curly brace + for key, value in behaviors.items(): + # Escape any curly braces in the content and format safely + safe_value = value.replace("{", "{{").replace("}", "}}") + safe_key = key.replace("{", "{{").replace("}", "}}") + behaviors_str += f' "{safe_key}": "{safe_value}",\n' + behaviors_str = ( + behaviors_str.rstrip(",\n") + "\n}}" + ) # End with escaped curly brace + + # Replace placeholders in system prompt + system_prompt = system_prompt.replace("{bt_3.xml}", bt_3_example) + system_prompt = system_prompt.replace("{BEHAVIORS}", behaviors_str) + + return system_prompt diff --git a/data_grammar/rlhf_generation/utils/run_robot_sim.py b/data_grammar/rlhf_generation/utils/run_robot_sim.py new file mode 100644 index 0000000..a7ce216 --- /dev/null +++ b/data_grammar/rlhf_generation/utils/run_robot_sim.py @@ -0,0 +1,99 @@ +"""Run robot simulation with behavior tree and return metrics.""" + +import os +from typing import Any, Dict, Optional, Type + +from vi import Config, Window + +from agent_control import RobotEnvironment +from tree_parser import save_behavior_tree_xml + + +def run_robot_sim( + bt_str: str, + environment_class: Optional[Type[Any]] = None, + environment_kwargs: Optional[Dict[str, Any]] = None, + config_kwargs: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Run a robot environment simulation with behavior trees. 
+ + Args: + bt_str: The behavior tree XML string + environment_class: The environment class to use (defaults to RobotEnvironment) + environment_kwargs: Dict of keyword arguments for the environment constructor + Example: {'n_agents': 10, 'n_parts': 15, 'task': 'collect', 'headless': False} + config_kwargs: Dict of keyword arguments for the Config constructor + Example: {'radius': 25, 'duration': 1000, 'window_size': 500, 'movement_speed': 1.0} + + Returns: + Dict containing simulation metrics + """ + # Use defaults if not provided + environment_class = environment_class or RobotEnvironment + environment_kwargs = environment_kwargs or {} + config_kwargs = config_kwargs or {} + + # Set default config values if not provided + default_config = { + "radius": 25, # Agent detection radius + "visualise_chunks": True, # Show spatial partitioning + "window_size": 500, # Window size (will create square window) + "movement_speed": 1.0, # Agent movement speed + "duration": 1000, # Simulation duration in steps + } + + # Merge default config with provided config + final_config = {**default_config, **config_kwargs} + + # Handle window_size parameter - convert to Window object + if "window_size" in final_config: + window_size = final_config.pop("window_size") + final_config["window"] = Window.square(window_size) + elif "window" not in final_config: + # If neither window_size nor window is provided, use default + final_config["window"] = Window.square(500) + + # Configuration for the simulation + config = Config(**final_config) + + # Create a temp path for the behavior tree + bt_path = "./temp_bt.xml" + save_behavior_tree_xml(bt_str, bt_path) + + # Set default environment values if not provided + default_environment = { + "config": config, + "bt_path": bt_path, + "n_agents": 10, # Number of robot agents + "n_parts": 10, # Number of parts in environment + "task": "collect", # Task description + "headless": False, # Show GUI (set to True for no GUI) + } + + # Merge default environment 
with provided environment kwargs + # Note: provided kwargs override defaults, including config and bt_path + final_environment = {**default_environment, **environment_kwargs} + final_environment["config"] = config # Always use the constructed config + final_environment["bt_path"] = bt_path # Always use the temp bt_path + + # Create the environment + environment = environment_class(**final_environment) + environment.setup() + + try: + # Run the simulation and collect metrics + results = environment.run() + finally: + # Ensure pygame is properly cleaned up + try: + import pygame + + pygame.display.quit() + pygame.quit() + except: + pass # Ignore errors if pygame is not available or already cleaned up + + # Delete the temp behavior tree + os.remove(bt_path) + + return results diff --git a/data_grammar/rlhf_generation/utils/save_data_point.py b/data_grammar/rlhf_generation/utils/save_data_point.py new file mode 100644 index 0000000..86c7592 --- /dev/null +++ b/data_grammar/rlhf_generation/utils/save_data_point.py @@ -0,0 +1,87 @@ +"""Save a datapoint in the given path after producing other prompt styles.""" + +import json +import os +from typing import Any, Type + +from data_grammar.grammar_gen.node_translations import node_connectors +from data_grammar.grammar_gen.tree_to_prompt import ( + generate_spoon_prompt_from_string, + generate_technical_prompt_from_string, +) +from tree_parser.agent_doc_parser import AgentDocstringParser + + +def save_datapoint( + dataset_path: str, task_description: str, tree_str: str, agent_class: Type[Any] +) -> None: + """ + Save a single datapoint to the dataset file after generating other prompt styles. 
+ + Args: + dataset_path: Path to the JSON dataset file (should end with .json) + task_description: The layman task description + tree_str: The behavior tree XML string + agent_class: The agent class to extract translations from + """ + # Ensure the dataset path has .json extension + if not dataset_path.endswith(".json"): + dataset_path = dataset_path + ".json" + + # Retrieve the translations + doc_parser = AgentDocstringParser(agent_class=agent_class) + extracted_info_dict = doc_parser.extract_docstring_config() + spoon_translations = extracted_info_dict["spoon_node_translations"] + technical_translations = extracted_info_dict["node_translations"] + + # Produce the prompts + layman_prompt = task_description + try: + spoon_prompt = generate_spoon_prompt_from_string( + tree_string=tree_str, + spoon_translations=spoon_translations, + connectors=node_connectors, + ) + except Exception as e: + spoon_prompt = f"Error generating spoon task: {e}" + + try: + tech_prompt = generate_technical_prompt_from_string( + tree_string=tree_str, + translations=technical_translations, + connectors=node_connectors, + ) + except Exception as e: + tech_prompt = f"Error generating technical task: {e}" + + # Create the datapoint in the same format as the other generators + datapoint = { + "layman_task": layman_prompt, + "technical_task": tech_prompt, + "spoon_task": spoon_prompt, + "tree": tree_str, + } + + # Load existing dataset or create new one + dataset = [] + if os.path.exists(dataset_path): + try: + with open(dataset_path, "r") as json_file: + dataset = json.load(json_file) + except (json.JSONDecodeError, FileNotFoundError): + # If file exists but is corrupted or empty, start fresh + dataset = [] + + # Append the new datapoint + dataset.append(datapoint) + + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(dataset_path), exist_ok=True) + + # Save the updated dataset + with open(dataset_path, "w") as json_file: + json.dump(dataset, json_file, indent=4) + + print( + 
f"Datapoint saved to {dataset_path}. Dataset now contains {len(dataset)} samples." + ) diff --git a/examples/03_data_gen_examples/rlhf_dataset_gen.py b/examples/03_data_gen_examples/rlhf_dataset_gen.py new file mode 100644 index 0000000..a8fcb19 --- /dev/null +++ b/examples/03_data_gen_examples/rlhf_dataset_gen.py @@ -0,0 +1,61 @@ +""" +Example showing how to use the UnifiedRLHFUI class to generate +behavior trees and datasets through a collaboration between LLM and human. +""" +from agent_control import RobotAgent, RobotEnvironment +from data_grammar.rlhf_generation.rlhf_unified import UnifiedRLHFUI + +# Step 1: Define grammar rules for behavior tree validation +grammar_rules = { + "B": [["b", ["SEL"]], ["b", ["SEQ"]]], + "SEL": [["sel", ["SEQn", "As"]], ["sel", ["SEQn"]]], + "SEQn": [["SEQ", "SEQn"], ["SEQ"]], + "SEQ": [["seq", ["Pn", "A"]], ["seq", ["As", "Pn", "A"]]], + "b": ["BehaviorTree", ["children_nodes"]], + "sel": ["Selector", ["children_nodes"]], + "seq": ["Sequence", ["children_nodes"]], + "A": [["aa", "sa"], ["aa"], ["sa"]], + "As": [["aa"], ["sa"]], + "aa": ["ActuatorAction"], + "sa": ["StateAction"], + "Pn": [["p", "Pn"], ["p"], []], + "p": ["Condition"], +} + +# Step 2: Configure simulation parameters +config_kwargs = { + "radius": 40, + "duration": 500, + "movement_speed": 1.0, + "window_size": 800, + "visualise_chunks": True, +} + +# Step 3: Configure environment parameters +environment_kwargs = { + "n_agents": 10, + "n_parts": 10, + "task": "custom_task", + "headless": False, +} + +# Step 4: Initialize the RLHF dataset generation UI +# This creates a unified interface for interactive behavior tree generation +ui = UnifiedRLHFUI( + dataset_path="./examples/03_data_gen_examples/output/dataset_path.json", # Output dataset file + dataset_size_goal=10, # Target number of datapoints + agent_class=RobotAgent, # Agent class for validation + grammar_rules=grammar_rules, # Grammar rules defined above + environment_class=RobotEnvironment, # Environment 
class for simulation + environment_kwargs=environment_kwargs, # Environment parameters + config_kwargs=config_kwargs, # Simulation configuration +) + +# Step 5: Start the interactive RLHF dataset generation process +# - Define tasks in natural language +# - Generate behavior trees using LLM +# - Validate trees against grammar and primitives +# - Run simulations to test behavior +# - Provide human feedback for iterative improvement +# - Save successful datapoints to the dataset +ui.run() \ No newline at end of file diff --git a/llm_interface/layer_LLM.py b/llm_interface/layer_LLM.py index b5842d0..f16e277 100644 --- a/llm_interface/layer_LLM.py +++ b/llm_interface/layer_LLM.py @@ -21,7 +21,7 @@ except ImportError: OLLAMA_AVAILABLE = False -from dotenv import load_dotenv # type: ignore +from dotenv import load_dotenv from lancedb import table from lancedb.pydantic import LanceModel from openai import OpenAI @@ -406,7 +406,7 @@ def _auto_import_gguf(self, gguf_path: str) -> str: with open(modelfile_path, "r") as f: modelfile_content = f.read() - ollama.create(model=model_name, modelfile=modelfile_content) + ollama.create(model=model_name, modelfile=modelfile_content) # type: ignore[call-overload] print(f"Successfully imported model as: {model_name}") # Return with :latest suffix since that's how Ollama stores it return f"{model_name}:latest" @@ -414,7 +414,7 @@ def _auto_import_gguf(self, gguf_path: str) -> str: # If API also fails, try one more approach - stream the creation try: print("Trying streaming create...") - for response in ollama.create( + for response in ollama.create( # type: ignore[call-overload] model=model_name, modelfile=modelfile_content, stream=True ): if "status" in response: