diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..f2bccc7
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Andres Garcia
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 1100467..b417b87 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,7 @@ Explore synthetic dataset generation using formal grammars:
- **`custom_grammar_gen.py`** - Generate datasets using custom grammar configurations
- **`filtered_dataset_generation.py`** - Advanced dataset generation with simulation-based filtering
- **`mixed_gen.py`** - Mixed dataset generation combining multiple approaches
+- **`rlhf_dataset_gen.py`** - Interactive RLHF-inspired dataset generation with human-in-the-loop feedback using a unified GUI
#### **Experimental Setup Examples** (`examples/04_experimental_setup_example/`)
Complete experimental pipeline example for research and evaluation:
diff --git a/agent_control/simulation/agents/agent_sensors.py b/agent_control/simulation/agents/agent_sensors.py
index 8036784..987a759 100644
--- a/agent_control/simulation/agents/agent_sensors.py
+++ b/agent_control/simulation/agents/agent_sensors.py
@@ -25,7 +25,7 @@ def sense_light(self) -> Vector2:
Returns:
Vector2: Direction vector to the nearest light source
"""
- light_pos = self.agent.env.light_pos
+ light_pos = self.agent.env.light_pos # type: ignore[attr-defined]
agent_pos = self.agent.pos
diff_vec: Vector2 = light_pos - agent_pos
diff --git a/agent_control/simulation/agents/elements.py b/agent_control/simulation/agents/elements.py
index 252cc6d..ba9c199 100644
--- a/agent_control/simulation/agents/elements.py
+++ b/agent_control/simulation/agents/elements.py
@@ -11,7 +11,7 @@
from ..envs.robot_env import RobotEnvironment
-class Part(Agent): # type: ignore
+class Part(Agent):
"""Part object that can be picked up and moved by robot agents."""
def __init__(
@@ -38,7 +38,7 @@ def __init__(
self.simulation = simulation
self.type = type
self.pos = pos
- self.owner: Optional[Agent] = None
+ self.owner: Optional[Agent | "RobotEnvironment"] = None
self.env = env
self.is_permanently_placed = False
self.update_img()
@@ -50,11 +50,11 @@ def update_img(self) -> None:
def remove_part(self) -> None:
"""Remove this part from the simulation."""
- self.kill()
+ self.kill() # type: ignore[no-untyped-call]
def update(self) -> None:
"""Update the part's position and state."""
- if self.owner and not self.is_permanently_placed:
+ if self.owner and not self.is_permanently_placed and hasattr(self.owner, "pos"):
self.pos = self.owner.pos
def can_be_picked_up(self) -> bool:
diff --git a/agent_control/simulation/agents/robot_agent.py b/agent_control/simulation/agents/robot_agent.py
index b69db5e..467855a 100644
--- a/agent_control/simulation/agents/robot_agent.py
+++ b/agent_control/simulation/agents/robot_agent.py
@@ -5,7 +5,7 @@
import pygame as pg
from pygame.math import Vector2
-from vi import Agent, Simulation
+from vi import Agent, HeadlessSimulation, Simulation
from tree_parser.middle_parser import parse_behavior_tree
@@ -16,13 +16,13 @@
from ..envs.base_env import SimEnvironment
-class RobotAgent(Agent): # type: ignore
+class RobotAgent(Agent):
"""Robot agent that executes behavior trees in simulation environments."""
def __init__(
self,
images: List[pg.Surface],
- simulation: Simulation,
+ simulation: Simulation | HeadlessSimulation,
pos: Vector2,
env: "SimEnvironment",
xml_path: str,
diff --git a/agent_control/simulation/envs/robot_env.py b/agent_control/simulation/envs/robot_env.py
index ec06f28..fae9e6d 100644
--- a/agent_control/simulation/envs/robot_env.py
+++ b/agent_control/simulation/envs/robot_env.py
@@ -88,32 +88,32 @@ def draw_arena(self) -> None:
"""Draw arena boundaries in the simulation environment."""
self.simulation.spawn_obstacle(
"./agent_control/simulation/images/arena_new.png",
- self.arena_pos.x,
- self.arena_pos.y,
+ int(self.arena_pos.x),
+ int(self.arena_pos.y),
)
def draw_source(self) -> None:
"""Draw source area where good parts are located."""
self.simulation.spawn_site(
"./agent_control/simulation/images/source_green.png",
- self.source_pos.x,
- self.source_pos.y,
+ int(self.source_pos.x),
+ int(self.source_pos.y),
)
def draw_nest(self) -> None:
"""Draw base/nest area where parts should be delivered."""
self.simulation.spawn_site(
"./agent_control/simulation/images/blue_nest.png",
- self.base_pos.x,
- self.base_pos.y,
+ int(self.base_pos.x),
+ int(self.base_pos.y),
)
def draw_waste(self) -> None:
"""Draw waste area where bad parts should be disposed."""
self.simulation.spawn_site(
"./agent_control/simulation/images/waste_red.png",
- self.waste_pos.x,
- self.waste_pos.y,
+ int(self.waste_pos.x),
+ int(self.waste_pos.y),
)
def spawn_part(self, type: str, pos: Vector2) -> None:
@@ -141,7 +141,7 @@ def remove_part(self, part: Part) -> None:
Args:
part: Part object to remove from simulation
"""
- part.kill()
+ part.kill() # type: ignore[no-untyped-call]
def place_parts(self, num_parts: int) -> None:
"""
diff --git a/data_grammar/__init__.py b/data_grammar/__init__.py
index dffa68a..b88242d 100644
--- a/data_grammar/__init__.py
+++ b/data_grammar/__init__.py
@@ -1,5 +1,6 @@
"""Data Grammar package for generating behavior trees and datasets."""
from .dataset_generator import DatasetGenerator
+from .rlhf_generation.rlhf_unified import UnifiedRLHFUI
-__all__ = ["DatasetGenerator"]
+__all__ = ["DatasetGenerator", "UnifiedRLHFUI"]
diff --git a/data_grammar/dataset_generation/sys_prompt.py b/data_grammar/dataset_generation/sys_prompt.py
index 99a00f0..6f5efdc 100644
--- a/data_grammar/dataset_generation/sys_prompt.py
+++ b/data_grammar/dataset_generation/sys_prompt.py
@@ -504,7 +504,7 @@
Metrics:
-{"good_parts_picked_up": 1}
+{{"good_parts_picked_up": 1}}
Task: Your task is to find a good part and stop moving.
@@ -528,7 +528,7 @@
Metrics:
-{"good_parts_picked_up": 1}
+{{"good_parts_picked_up": 1}}
Task: Your task is to find a scrap part and stop moving.
@@ -552,7 +552,7 @@
Metrics:
-{"bad_parts_picked_up": 1}
+{{"bad_parts_picked_up": 1}}
-Examples for finding and bringing parts to areas:
@@ -584,7 +584,7 @@
Metrics:
-{"good_parts_picked_up": 1, "parts_dropped_in_construction": [1, 0]}
+{{"good_parts_picked_up": 1, "parts_dropped_in_construction": [1, 0]}}
Task: Your task is to find scrap parts and bring them to the source area.
@@ -614,7 +614,7 @@
Metrics:
-{"bad_parts_picked_up": 1, "parts_dropped_in_source": [0, 1]}
+{{"bad_parts_picked_up": 1, "parts_dropped_in_source": [0, 1]}}
- Complex example on how to handle both good and scrap parts to different areas:
@@ -666,7 +666,7 @@
Metrics:
-{"good_parts_picked_up": 1, "bad_parts_picked_up": 1, "parts_dropped_in_storage": [1, 0], "parts_dropped_in_source": [0, 1]}
+{{"good_parts_picked_up": 1, "bad_parts_picked_up": 1, "parts_dropped_in_storage": [1, 0], "parts_dropped_in_source": [0, 1]}}
You can use the examples above to help you understand how nodes come toguether to build functional units of a behaviour tree, but you can also take direct inspiration to populate a tree if the strucutre is the same or similar to an example.
@@ -722,7 +722,7 @@
Metrics:
-{"bad_parts_picked_up": 1}
+{{"bad_parts_picked_up": 1}}
"""
diff --git a/data_grammar/dataset_generator.py b/data_grammar/dataset_generator.py
index 839cb5b..476b1a1 100644
--- a/data_grammar/dataset_generator.py
+++ b/data_grammar/dataset_generator.py
@@ -27,7 +27,7 @@
# Add type: ignore for dotenv import since it's an optional dependency
try:
- from dotenv import load_dotenv # type: ignore
+ from dotenv import load_dotenv
load_dotenv()
except ImportError:
diff --git a/data_grammar/rlhf_generation/__init__.py b/data_grammar/rlhf_generation/__init__.py
new file mode 100644
index 0000000..ecc99e7
--- /dev/null
+++ b/data_grammar/rlhf_generation/__init__.py
@@ -0,0 +1 @@
+"""Module for generating datasets through methods inspired by RLHF."""
diff --git a/data_grammar/rlhf_generation/output/dataset_path.json b/data_grammar/rlhf_generation/output/dataset_path.json
new file mode 100644
index 0000000..deea0ff
--- /dev/null
+++ b/data_grammar/rlhf_generation/output/dataset_path.json
@@ -0,0 +1,98 @@
+[
+ {
+ "layman_task": "Find scrap parts and pick them up",
+ "technical_task": "if you detect a scrap part then pick up the part, or otherwise then search randomly.",
+ "spoon_task": "if you detect a scrap part then pick up the part, or otherwise then walk randomly.",
+ "tree": "\n \n \n is_scrap_part_detected\n pick_up_part\n \n state_random_walk\n \n"
+ },
+ {
+ "layman_task": "Find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n"
+ },
+ {
+ "layman_task": "Find good parts and take them to the base",
+ "technical_task": "if you are holding a good part then go to the base, or if you detect a good part then pick up the part, or otherwise then go to the source.",
+ "spoon_task": "if you are holding a good part then seek the base area, or if you detect a good part then pick up the part, or otherwise then seek the source area.",
+ "tree": "\n \n \n is_agent_holding_good_part\n state_seek_base_area\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n"
+ },
+ {
+ "layman_task": "Find scrap parts and take them to the waste",
+ "technical_task": "if you detect a scrap part then pick up the part, or if you are holding a scrap part then go to the waste, or if you are in the waste then drop the part, or otherwise then search randomly.",
+ "spoon_task": "if you detect a scrap part then pick up the part, or if you are holding a scrap part then seek the waste area, or if you are in the waste area then drop the part, or otherwise then walk randomly.",
+ "tree": "\n \n \n is_scrap_part_detected\n pick_up_part\n \n \n is_agent_holding_scrap_part\n state_seek_waste_area\n \n \n is_agent_in_waste_area\n drop_part\n \n state_random_walk\n \n"
+ },
+ {
+ "layman_task": "Find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source area.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n"
+ },
+ {
+ "layman_task": "find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n"
+ },
+ {
+ "layman_task": "Find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source area.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n"
+ },
+ {
+ "layman_task": "Find scrap parts and pick them up",
+ "technical_task": "if you detect a scrap part then pick up the part, or otherwise then search randomly.",
+ "spoon_task": "if you detect a scrap part then pick up the part, or otherwise then walk randomly.",
+ "tree": "\n \n \n is_scrap_part_detected\n pick_up_part\n \n state_random_walk\n \n"
+ },
+ {
+    "layman_task": "Find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n"
+ },
+ {
+ "layman_task": "Find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source area.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n"
+ },
+ {
+ "layman_task": "Find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n"
+ },
+ {
+ "layman_task": "Find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source area.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n"
+ },
+ {
+    "layman_task": "Find scrap parts and pick them up",
+ "technical_task": "if you detect a scrap part then pick up the part, or otherwise then search randomly.",
+ "spoon_task": "if you detect a scrap part then pick up the part, or otherwise then walk randomly.",
+ "tree": "\n \n \n is_scrap_part_detected\n pick_up_part\n \n state_random_walk\n \n"
+ },
+ {
+ "layman_task": "Find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then go to the source.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then seek the source area.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_seek_source_area\n \n"
+ },
+ {
+ "layman_task": "Find good parts and pick them up",
+ "technical_task": "if you detect a good part then pick up the part, or otherwise then search randomly.",
+ "spoon_task": "if you detect a good part then pick up the part, or otherwise then walk randomly.",
+ "tree": "\n \n \n is_good_part_detected\n pick_up_part\n \n state_random_walk\n \n"
+ },
+ {
+    "layman_task": "Find scrap parts and pick them up and take them to the waste area and drop them there",
+ "technical_task": "if you are holding a scrap part and if you are in the waste then drop the part, or if you are holding a scrap part then go to the waste, or otherwise if you detect a scrap part then pick up the part.",
+ "spoon_task": "if you are holding a scrap part and if you are in the waste area then drop the part, or if you are holding a scrap part then seek the waste area, or otherwise if you detect a scrap part then pick up the part.",
+ "tree": "\n \n \n is_agent_holding_scrap_part\n is_agent_in_waste_area\n drop_part\n \n \n is_agent_holding_scrap_part\n state_seek_waste_area\n \n \n is_scrap_part_detected\n pick_up_part\n \n \n"
+ }
+]
\ No newline at end of file
diff --git a/data_grammar/rlhf_generation/rlhf_unified.py b/data_grammar/rlhf_generation/rlhf_unified.py
new file mode 100644
index 0000000..d3c33ea
--- /dev/null
+++ b/data_grammar/rlhf_generation/rlhf_unified.py
@@ -0,0 +1,1206 @@
+"""Unified RLHF-inspired dataset generation interface using Tkinter GUI."""
+
+import os
+import threading
+import tkinter as tk
+from tkinter import scrolledtext, ttk
+from typing import Any, Dict, List, Optional, Type
+
+from dotenv import load_dotenv
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_openai import ChatOpenAI
+from pydantic import BaseModel, Field
+from vi import Agent
+
+from agent_control import RobotAgent, RobotEnvironment
+from data_grammar.rlhf_generation.utils.prompt_builder import PromptBuilder
+from data_grammar.rlhf_generation.utils.run_robot_sim import run_robot_sim
+from data_grammar.rlhf_generation.utils.save_data_point import save_datapoint
+from tree_parser import AgentDocstringParser, BehaviorTreeGrammarValidator
+from tree_parser.primitives_validator import validate_primitives
+
+load_dotenv()
+
+openai_api_key = os.getenv("OPENAI_API_KEY")
+
+
+class GraphState(BaseModel):
+ """Represents the state of our graph."""
+
+ task_definition: str = Field(
+ description="The natural language definition of the task to be executed"
+ )
+ task_metrics_goal: str = Field(
+ description="The user defined metrics the tree should achieve in the simulation"
+ )
+ behaviour_tree: Optional[str] = Field(
+        default=None, description="The behavior tree returned by the LLM"
+ )
+ passed_validator: bool = Field(
+ default=False,
+ description="Whether the tree passed both the grammar and primitive validator checks",
+ )
+ validator_feedback: Optional[str] = Field(
+ default=None,
+        description="If the tree doesn't pass, the feedback of both the grammar and primitive validator checks",
+ )
+ task_metrics_result: Optional[Dict[str, Any]] = Field(
+ default=None,
+ description="The metrics returned by executing the tree on the simulator",
+ )
+ human_feedback: Optional[str] = Field(
+ default=None,
+ description="The feedback of the human on the task execution that will be given to the LLM in case of retries",
+ )
+ dataset_size: int = Field(description="The number of samples generated")
+ dataset_path: str = Field(description="The path to the dataset")
+ dataset_size_goal: int = Field(description="The goal number of samples to generate")
+
+
+# ------------------------------ Unified UI Class ------------------------------
+class UnifiedRLHFUI:
+ """Unified RLHF dataset generation interface with persistent GUI."""
+
+ def __init__(
+ self,
+ dataset_path: str,
+ dataset_size_goal: int,
+ agent_class: Optional[Type[Agent]] = None,
+ grammar_rules: Optional[Dict[str, Any]] = None,
+ environment_class: Optional[Type[Any]] = None,
+ environment_kwargs: Optional[Dict[str, Any]] = None,
+ config_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """Initialize the unified RLHF UI with the given parameters.
+
+ Args:
+ dataset_path: Path to the dataset file.
+ dataset_size_goal: Target number of datapoints to generate.
+ agent_class: The agent class to use for behavior tree validation.
+ grammar_rules: Custom grammar rules for behavior tree validation.
+ environment_class: Custom environment class for simulation.
+ environment_kwargs: Arguments to pass to the environment constructor.
+ config_kwargs: Configuration arguments for the agent.
+ """
+ self.dataset_path = dataset_path
+ self.dataset_size_goal = dataset_size_goal
+ self.agent_class = agent_class or RobotAgent
+ self.grammar_rules = grammar_rules or self._get_default_grammar_rules()
+ self.environment_class = environment_class or RobotEnvironment
+ default_env_kwargs = {"headless": True}
+ if environment_kwargs:
+ default_env_kwargs.update(environment_kwargs)
+ self.environment_kwargs = default_env_kwargs
+ self.config_kwargs = config_kwargs or {}
+
+ self.current_state: Optional[GraphState] = None
+ self.workflow_running = False
+ self._feedback_mode: bool = False
+
+ # Create main window
+ self.root = tk.Tk()
+ agent_name = self.agent_class.__name__ if self.agent_class else "Unknown Agent"
+ self.root.title(f"RLHF Dataset Generation - {agent_name}")
+ self.root.geometry("1000x700")
+ self.root.resizable(True, True)
+
+ # Initialize workflow components
+ self.setup_workflow()
+ self.create_ui()
+ self.reset_workflow()
+
+ def _get_default_grammar_rules(self) -> Dict[str, Any]:
+ """Get default grammar rules for behavior tree validation.
+
+ Returns:
+ Dictionary containing default grammar rules for behavior tree validation.
+ """
+ return {
+ "B": [["b", ["SEL"]], ["b", ["SEQ"]]],
+ "SEL": [["sel", ["SEQn", "As"]], ["sel", ["SEQn"]]],
+ "SEQn": [["SEQ", "SEQn"], ["SEQ"]],
+ "SEQ": [["seq", ["Pn", "A"]], ["seq", ["As", "Pn", "A"]]],
+ "b": ["BehaviorTree", ["children_nodes"]],
+ "sel": ["Selector", ["children_nodes"]],
+ "seq": ["Sequence", ["children_nodes"]],
+ "A": [["aa", "sa"], ["aa"], ["sa"]],
+ "As": [["aa"], ["sa"]],
+ "aa": ["ActuatorAction"],
+ "sa": ["StateAction"],
+ "Pn": [["p", "Pn"], ["p"], []],
+ "p": ["Condition"],
+ }
+
+ def setup_workflow(self) -> None:
+ """Initialize the workflow graph and LLM components."""
+ self.prompt_builder = PromptBuilder(self.agent_class)
+ self.system_prompt = self.prompt_builder.build_system_prompt()
+
+ self.tree_generator_prompt = ChatPromptTemplate.from_messages(
+ [("system", self.system_prompt), ("placeholder", "{user_prompt}")]
+ )
+
+ self.tree_generator_llm = ChatOpenAI(model="gpt-4o", temperature=0)
+
+ # Initialize agent doc parser for extracting node information
+ if self.agent_class:
+ self.agent_doc_parser: Optional[AgentDocstringParser] = (
+ AgentDocstringParser(self.agent_class)
+ )
+ self.agent_config: Optional[Dict[str, Any]] = (
+ self.agent_doc_parser.extract_docstring_config()
+ )
+ else:
+ self.agent_doc_parser = None
+ self.agent_config = None
+
+ def create_ui(self) -> None:
+ """Create the main UI with tabs."""
+ # Create notebook for tabs
+ self.notebook = ttk.Notebook(self.root)
+ self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
+
+ # Create tabs
+ self.create_main_tab()
+ self.create_nodes_tab()
+ self.create_dataset_tab()
+
+ def create_main_tab(self) -> None:
+ """Create the main workflow tab."""
+ self.main_frame = ttk.Frame(self.notebook)
+ self.notebook.add(self.main_frame, text="Main Workflow")
+
+ # Create main layout
+ main_container = ttk.Frame(self.main_frame)
+ main_container.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
+
+ # Left side - Controls and inputs
+ left_frame = ttk.Frame(main_container)
+ left_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=False, padx=(0, 10))
+ left_frame.config(width=400)
+
+ # Right side - Information display
+ right_frame = ttk.Frame(main_container)
+ right_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
+
+ self.create_left_panel(left_frame)
+ self.create_right_panel(right_frame)
+
+ def create_left_panel(self, parent: tk.Widget) -> tk.Widget:
+ """Create the left control panel.
+
+ Args:
+ parent: The parent widget to attach the panel to.
+
+ Returns:
+ The created left panel widget.
+ """
+ # Dataset info
+ info_frame = ttk.LabelFrame(parent, text="Dataset Information", padding=10)
+ info_frame.pack(fill=tk.X, pady=(0, 10))
+
+ self.dataset_size_label = ttk.Label(
+ info_frame, text="Current Size: 0", font=("Arial", 10)
+ )
+ self.dataset_size_label.pack(anchor=tk.W)
+
+ self.dataset_goal_label = ttk.Label(
+ info_frame,
+ text=f"Target Size: {self.dataset_size_goal}",
+ font=("Arial", 10),
+ )
+ self.dataset_goal_label.pack(anchor=tk.W)
+
+ self.progress_label = ttk.Label(
+ info_frame, text="Progress: 0.0%", font=("Arial", 10, "bold")
+ )
+ self.progress_label.pack(anchor=tk.W)
+
+ # Task input
+ task_frame = ttk.LabelFrame(parent, text="Task Definition", padding=10)
+ task_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10))
+
+ ttk.Label(
+ task_frame, text="Task Description:", font=("Arial", 10, "bold")
+ ).pack(anchor=tk.W)
+ self.task_entry = tk.Text(task_frame, height=6, wrap=tk.WORD)
+ self.task_entry.pack(fill=tk.BOTH, expand=True, pady=(5, 10))
+
+ ttk.Label(task_frame, text="Metrics Goal:", font=("Arial", 10, "bold")).pack(
+ anchor=tk.W
+ )
+ self.metrics_entry = tk.Text(task_frame, height=4, wrap=tk.WORD)
+ self.metrics_entry.pack(fill=tk.BOTH, expand=True, pady=(5, 10))
+
+ # Feedback input (initially hidden)
+ self.feedback_frame = ttk.LabelFrame(parent, text="Human Feedback", padding=10)
+ ttk.Label(
+ self.feedback_frame,
+ text="Feedback for improvement:",
+ font=("Arial", 10, "bold"),
+ ).pack(anchor=tk.W)
+ self.feedback_entry = tk.Text(self.feedback_frame, height=4, wrap=tk.WORD)
+ self.feedback_entry.pack(fill=tk.BOTH, expand=True, pady=(5, 0))
+
+ # Buttons frame
+ button_frame = ttk.Frame(parent)
+ button_frame.pack(fill=tk.X, pady=10)
+
+ self.prompt_button = ttk.Button(
+ button_frame, text="Generate Tree", command=self.start_generation
+ )
+ self.prompt_button.pack(side=tk.LEFT, padx=(0, 5))
+
+ self.run_sim_button = ttk.Button(
+ button_frame,
+ text="Run Simulation",
+ command=self.run_simulation,
+ state=tk.DISABLED,
+ )
+ self.run_sim_button.pack(side=tk.LEFT, padx=5)
+
+ self.feedback_button = ttk.Button(
+ button_frame,
+ text="Give Feedback",
+ command=self.give_feedback,
+ state=tk.DISABLED,
+ )
+ self.feedback_button.pack(side=tk.LEFT, padx=5)
+
+ self.save_button = ttk.Button(
+ button_frame,
+ text="Save Datapoint",
+ command=self.save_datapoint,
+ state=tk.DISABLED,
+ )
+ self.save_button.pack(side=tk.LEFT, padx=5)
+
+ return parent
+
+ def create_right_panel(self, parent: tk.Widget) -> tk.Widget:
+ """Create the right information display panel.
+
+ Args:
+ parent: The parent widget to attach the panel to.
+
+ Returns:
+ The created right panel widget.
+ """
+ # Status display
+ status_frame = ttk.LabelFrame(parent, text="Status", padding=10)
+ status_frame.pack(fill=tk.X, pady=(0, 10))
+
+ self.status_label = ttk.Label(
+ status_frame,
+ text="Ready to start",
+ font=("Arial", 12, "bold"),
+ foreground="blue",
+ )
+ self.status_label.pack()
+
+ # Generated tree display
+ tree_frame = ttk.LabelFrame(parent, text="Generated Behavior Tree", padding=10)
+ tree_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10))
+
+ self.tree_display = scrolledtext.ScrolledText(
+ tree_frame, height=12, wrap=tk.WORD, state=tk.DISABLED
+ )
+ self.tree_display.pack(fill=tk.BOTH, expand=True)
+
+ # Metrics/Feedback display
+ metrics_frame = ttk.LabelFrame(parent, text="Validation & Metrics", padding=10)
+ metrics_frame.pack(fill=tk.BOTH, expand=True)
+
+ self.metrics_display = scrolledtext.ScrolledText(
+ metrics_frame, height=8, wrap=tk.WORD, state=tk.DISABLED
+ )
+ self.metrics_display.pack(fill=tk.BOTH, expand=True)
+
+ return parent
+
+ def create_nodes_tab(self) -> None:
+ """Create the nodes information tab."""
+ self.nodes_frame = ttk.Frame(self.notebook)
+ self.notebook.add(self.nodes_frame, text="Available Nodes")
+
+ # Main container
+ main_container = ttk.Frame(self.nodes_frame)
+ main_container.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
+
+ # Title
+ title_frame = ttk.Frame(main_container)
+ title_frame.pack(fill=tk.X, pady=(0, 10))
+
+ ttk.Label(
+ title_frame, text="Available Agent Nodes", font=("Arial", 16, "bold")
+ ).pack(side=tk.LEFT)
+
+ # Create notebook for different node types
+ self.nodes_notebook = ttk.Notebook(main_container)
+ self.nodes_notebook.pack(fill=tk.BOTH, expand=True)
+
+ # Create tabs for each node type
+ self.create_conditions_tab()
+ self.create_actuator_actions_tab()
+ self.create_state_actions_tab()
+
+ def create_conditions_tab(self) -> None:
+ """Create tab for condition nodes."""
+ conditions_frame = ttk.Frame(self.nodes_notebook)
+ self.nodes_notebook.add(conditions_frame, text="Conditions")
+
+ # Scrollable text area
+ text_frame = ttk.Frame(conditions_frame)
+ text_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
+
+ self.conditions_display = scrolledtext.ScrolledText(
+ text_frame, wrap=tk.WORD, state=tk.DISABLED
+ )
+ self.conditions_display.pack(fill=tk.BOTH, expand=True)
+
+ # Populate with condition nodes
+ self.populate_conditions()
+
+ def create_actuator_actions_tab(self) -> None:
+ """Create tab for actuator action nodes."""
+ actuator_frame = ttk.Frame(self.nodes_notebook)
+ self.nodes_notebook.add(actuator_frame, text="Actuator Actions")
+
+ # Scrollable text area
+ text_frame = ttk.Frame(actuator_frame)
+ text_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
+
+ self.actuator_display = scrolledtext.ScrolledText(
+ text_frame, wrap=tk.WORD, state=tk.DISABLED
+ )
+ self.actuator_display.pack(fill=tk.BOTH, expand=True)
+
+ # Populate with actuator action nodes
+ self.populate_actuator_actions()
+
+ def create_state_actions_tab(self) -> None:
+ """Create tab for state action nodes."""
+ state_frame = ttk.Frame(self.nodes_notebook)
+ self.nodes_notebook.add(state_frame, text="State Actions")
+
+ # Scrollable text area
+ text_frame = ttk.Frame(state_frame)
+ text_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
+
+ self.state_display = scrolledtext.ScrolledText(
+ text_frame, wrap=tk.WORD, state=tk.DISABLED
+ )
+ self.state_display.pack(fill=tk.BOTH, expand=True)
+
+ # Populate with state action nodes
+ self.populate_state_actions()
+
+ def populate_conditions(self) -> None:
+ """Populate conditions tab with available condition nodes."""
+ self.conditions_display.config(state=tk.NORMAL)
+ self.conditions_display.delete(1.0, tk.END)
+
+ # Configure text tags for formatting
+ self.conditions_display.tag_configure("title", font=("Arial", 14, "bold"))
+ self.conditions_display.tag_configure("node_name", font=("Arial", 13, "bold"))
+ self.conditions_display.tag_configure("description", font=("Arial", 10))
+
+ if self.agent_config is None:
+ self.conditions_display.insert(
+ tk.END, "No agent configuration available.", "title"
+ )
+ self.conditions_display.config(state=tk.DISABLED)
+ return
+
+ conditions = self.agent_config.get("conditions", [])
+
+ if conditions:
+ # Insert title
+ self.conditions_display.insert(
+ tk.END, "Available Condition Nodes:\n", "title"
+ )
+ self.conditions_display.insert(tk.END, "=" * 50 + "\n\n")
+
+ for i, condition in enumerate(conditions, 1):
+ self.conditions_display.insert(
+ tk.END, f"{i}. {condition}\n", "node_name"
+ )
+ self.conditions_display.insert(tk.END, "-" * 30 + "\n")
+
+ # Try to get more info about the condition from the agent
+ try:
+ if hasattr(self.agent_class, condition):
+ method = getattr(self.agent_class, condition)
+ if hasattr(method, "__doc__") and method.__doc__:
+ self.conditions_display.insert(
+ tk.END,
+ f"Description: {method.__doc__.strip()}\n",
+ "description",
+ )
+ else:
+ self.conditions_display.insert(
+ tk.END,
+ "Description: No documentation available\n",
+ "description",
+ )
+ else:
+ self.conditions_display.insert(
+ tk.END,
+ "Description: Method not found in agent class\n",
+ "description",
+ )
+ except Exception as e:
+ self.conditions_display.insert(
+ tk.END,
+ f"Description: Error retrieving info - {str(e)}\n",
+ "description",
+ )
+
+ self.conditions_display.insert(tk.END, "\n")
+ else:
+ self.conditions_display.insert(tk.END, "No condition nodes found.", "title")
+
+ self.conditions_display.config(state=tk.DISABLED)
+
+ def populate_actuator_actions(self) -> None:
+ """Populate actuator actions tab with available actuator action nodes."""
+ self.actuator_display.config(state=tk.NORMAL)
+ self.actuator_display.delete(1.0, tk.END)
+
+ # Configure text tags for formatting
+ self.actuator_display.tag_configure("title", font=("Arial", 14, "bold"))
+ self.actuator_display.tag_configure("node_name", font=("Arial", 13, "bold"))
+ self.actuator_display.tag_configure("description", font=("Arial", 10))
+
+ if self.agent_config is None:
+ self.actuator_display.insert(
+ tk.END, "No agent configuration available.", "title"
+ )
+ self.actuator_display.config(state=tk.DISABLED)
+ return
+
+ actuator_actions = self.agent_config.get("actuator_actions", [])
+
+ if actuator_actions:
+ self.actuator_display.insert(
+ tk.END, "Available Actuator Action Nodes:\n", "title"
+ )
+ self.actuator_display.insert(tk.END, "=" * 50 + "\n\n")
+
+ for i, action in enumerate(actuator_actions, 1):
+ self.actuator_display.insert(tk.END, f"{i}. {action}\n", "node_name")
+ self.actuator_display.insert(tk.END, "-" * 30 + "\n")
+
+ # Try to get more info about the action from the agent
+ try:
+ if hasattr(self.agent_class, action):
+ method = getattr(self.agent_class, action)
+ if hasattr(method, "__doc__") and method.__doc__:
+ self.actuator_display.insert(
+ tk.END,
+ f"Description: {method.__doc__.strip()}\n",
+ "description",
+ )
+ else:
+ self.actuator_display.insert(
+ tk.END,
+ "Description: No documentation available\n",
+ "description",
+ )
+ else:
+ self.actuator_display.insert(
+ tk.END,
+ "Description: Method not found in agent class\n",
+ "description",
+ )
+ except Exception as e:
+ self.actuator_display.insert(
+ tk.END,
+ f"Description: Error retrieving info - {str(e)}\n",
+ "description",
+ )
+
+ self.actuator_display.insert(tk.END, "\n")
+ else:
+ self.actuator_display.insert(
+ tk.END, "No actuator action nodes found.", "title"
+ )
+
+ self.actuator_display.config(state=tk.DISABLED)
+
+ def populate_state_actions(self) -> None:
+ """Populate state actions tab with available state action nodes."""
+ self.state_display.config(state=tk.NORMAL)
+ self.state_display.delete(1.0, tk.END)
+
+ # Configure text tags for formatting
+ self.state_display.tag_configure("title", font=("Arial", 14, "bold"))
+ self.state_display.tag_configure("node_name", font=("Arial", 13, "bold"))
+ self.state_display.tag_configure("description", font=("Arial", 10))
+
+ if self.agent_config is None:
+ self.state_display.insert(
+ tk.END, "No agent configuration available.", "title"
+ )
+ self.state_display.config(state=tk.DISABLED)
+ return
+
+ state_actions = self.agent_config.get("state_actions", [])
+
+ if state_actions:
+ self.state_display.insert(
+ tk.END, "Available State Action Nodes:\n", "title"
+ )
+ self.state_display.insert(tk.END, "=" * 50 + "\n\n")
+
+ for i, action in enumerate(state_actions, 1):
+ self.state_display.insert(tk.END, f"{i}. {action}\n", "node_name")
+ self.state_display.insert(tk.END, "-" * 30 + "\n")
+
+ # Try to get more info about the action from the agent
+ try:
+ if hasattr(self.agent_class, action):
+ method = getattr(self.agent_class, action)
+ if hasattr(method, "__doc__") and method.__doc__:
+ self.state_display.insert(
+ tk.END,
+ f"Description: {method.__doc__.strip()}\n",
+ "description",
+ )
+ else:
+ self.state_display.insert(
+ tk.END,
+ "Description: No documentation available\n",
+ "description",
+ )
+ else:
+ self.state_display.insert(
+ tk.END,
+ "Description: Method not found in agent class\n",
+ "description",
+ )
+ except Exception as e:
+ self.state_display.insert(
+ tk.END,
+ f"Description: Error retrieving info - {str(e)}\n",
+ "description",
+ )
+
+ self.state_display.insert(tk.END, "\n")
+ else:
+ self.state_display.insert(tk.END, "No state action nodes found.", "title")
+
+ self.state_display.config(state=tk.DISABLED)
+
+ def create_dataset_tab(self) -> None:
+ """Create the dataset exploration tab."""
+ self.dataset_frame = ttk.Frame(self.notebook)
+ self.notebook.add(self.dataset_frame, text="Dataset Explorer")
+
+ # Main container
+ main_container = ttk.Frame(self.dataset_frame)
+ main_container.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
+
+ # Title and refresh button
+ title_frame = ttk.Frame(main_container)
+ title_frame.pack(fill=tk.X, pady=(0, 10))
+
+ ttk.Label(
+ title_frame, text="Dataset Explorer", font=("Arial", 16, "bold")
+ ).pack(side=tk.LEFT)
+ self.refresh_button = ttk.Button(
+ title_frame, text="Refresh", command=self.refresh_dataset_list
+ )
+ self.refresh_button.pack(side=tk.RIGHT)
+
+ # Left side - Datapoint list
+ left_frame = ttk.LabelFrame(main_container, text="Datapoints", padding=10)
+ left_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=False, padx=(0, 10))
+ left_frame.config(width=300)
+
+ # Listbox with scrollbar for datapoints
+ list_frame = ttk.Frame(left_frame)
+ list_frame.pack(fill=tk.BOTH, expand=True)
+
+ self.datapoint_listbox = tk.Listbox(list_frame, width=40)
+ self.datapoint_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
+ self.datapoint_listbox.bind("<>", self.on_datapoint_select)
+
+ list_scrollbar = ttk.Scrollbar(
+ list_frame, orient=tk.VERTICAL, command=self.datapoint_listbox.yview
+ )
+ list_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
+ self.datapoint_listbox.config(yscrollcommand=list_scrollbar.set)
+
+ # Right side - Datapoint details
+ right_frame = ttk.LabelFrame(
+ main_container, text="Datapoint Details", padding=10
+ )
+ right_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
+
+ # Layman task display
+ ttk.Label(right_frame, text="Layman Task:", font=("Arial", 12, "bold")).pack(
+ anchor=tk.W, pady=(0, 5)
+ )
+ self.layman_display = scrolledtext.ScrolledText(
+ right_frame, height=6, wrap=tk.WORD, state=tk.DISABLED
+ )
+ self.layman_display.pack(fill=tk.X, pady=(0, 15))
+
+ # Behavior tree display
+ ttk.Label(right_frame, text="Behavior Tree:", font=("Arial", 12, "bold")).pack(
+ anchor=tk.W, pady=(0, 5)
+ )
+ self.tree_details_display = scrolledtext.ScrolledText(
+ right_frame, height=15, wrap=tk.WORD, state=tk.DISABLED
+ )
+ self.tree_details_display.pack(fill=tk.BOTH, expand=True)
+
+ # Load initial data
+ self.dataset_data: List[Dict[str, Any]] = []
+ self.refresh_dataset_list()
+
    def reset_workflow(self) -> None:
        """Reset the workflow to initial state.

        Rebuilds the GraphState from the on-disk dataset size, clears all
        input fields, and restores the default button/status configuration.
        """
        # Re-read the size from disk so the fresh state reflects any
        # datapoints saved during the previous iteration.
        current_size = self.get_current_dataset_size()

        self.current_state = GraphState(
            task_definition="",
            task_metrics_goal="",
            dataset_size=current_size,
            dataset_path=self.dataset_path,
            dataset_size_goal=self.dataset_size_goal,
        )
        self.workflow_running = False

        # Clear input fields for fresh start
        self.task_entry.delete(1.0, tk.END)
        self.metrics_entry.delete(1.0, tk.END)
        # The feedback entry is only guaranteed to exist once feedback mode
        # has been used, hence the hasattr guard.
        if hasattr(self, "feedback_entry"):
            self.feedback_entry.delete(1.0, tk.END)

        self.update_dataset_info()
        self.update_status("Ready to start", "blue")
        self.clear_displays()
        self.reset_buttons()
+
+ def get_current_dataset_size(self) -> int:
+ """Get current dataset size from file.
+
+ Returns:
+ The current number of datapoints in the dataset file.
+ """
+ import json
+
+ if os.path.exists(self.dataset_path):
+ try:
+ with open(self.dataset_path, "r") as f:
+ dataset = json.load(f)
+ return len(dataset)
+ except (json.JSONDecodeError, FileNotFoundError):
+ return 0
+ return 0
+
+ def update_dataset_info(self) -> None:
+ """Update dataset information display."""
+ current_size = self.get_current_dataset_size()
+ progress = (
+ (current_size / self.dataset_size_goal) * 100
+ if self.dataset_size_goal > 0
+ else 0
+ )
+
+ self.dataset_size_label.config(text=f"Current Size: {current_size}")
+ self.progress_label.config(text=f"Progress: {progress:.1f}%")
+
    def update_status(self, message: str, color: str = "black") -> None:
        """Update status message.

        Args:
            message: The status message to display.
            color: The color of the status message.
        """
        self.status_label.config(text=message, foreground=color)
        # Force an immediate redraw so the new status is visible right away.
        # NOTE(review): this method is also reached from worker threads (e.g.
        # via validate_tree); Tkinter is not thread-safe — confirm such calls
        # are marshalled to the main thread with root.after.
        self.root.update()
+
+ def clear_displays(self) -> None:
+ """Clear all display areas."""
+ self.tree_display.config(state=tk.NORMAL)
+ self.tree_display.delete(1.0, tk.END)
+ self.tree_display.config(state=tk.DISABLED)
+
+ self.metrics_display.config(state=tk.NORMAL)
+ self.metrics_display.delete(1.0, tk.END)
+ self.metrics_display.config(state=tk.DISABLED)
+
+ def reset_buttons(self) -> None:
+ """Reset button states."""
+ self.prompt_button.config(state=tk.NORMAL)
+ self.run_sim_button.config(state=tk.DISABLED)
+ self.feedback_button.config(state=tk.NORMAL)
+ self.save_button.config(state=tk.NORMAL)
+ self.feedback_frame.pack_forget()
+
    def start_generation(self) -> None:
        """Start the tree generation process.

        Reads the task/metrics fields, consumes pending human feedback when
        feedback mode is active, then launches generation on a daemon thread
        so the UI stays responsive.
        """
        # Ignore clicks while a generation/validation cycle is in flight.
        if self.workflow_running:
            return

        # Get input values
        task_definition = self.task_entry.get(1.0, tk.END).strip()
        task_metrics_goal = self.metrics_entry.get(1.0, tk.END).strip()

        if not task_definition:
            self.update_status("Please fill in the task definition", "red")
            return

        # Update state
        if self.current_state:
            self.current_state.task_definition = task_definition
            self.current_state.task_metrics_goal = task_metrics_goal

        # Get feedback if available
        # _feedback_mode is set by give_feedback(); it is consumed exactly
        # once here, after which the feedback frame is hidden again.
        if (
            hasattr(self, "_feedback_mode")
            and self._feedback_mode
            and self.current_state
        ):
            human_feedback = self.feedback_entry.get(1.0, tk.END).strip()
            self.current_state.human_feedback = (
                human_feedback if human_feedback else None
            )
            self._feedback_mode = False
            self.feedback_frame.pack_forget()
        else:
            if self.current_state:
                self.current_state.human_feedback = None

        # Start generation in separate thread
        self.workflow_running = True
        self.prompt_button.config(state=tk.DISABLED)
        self.update_status("Generating behavior tree...", "orange")

        threading.Thread(target=self.generate_tree, daemon=True).start()
+
    def generate_tree(self) -> None:
        """Generate behavior tree using LLM.

        Runs on a worker thread (spawned by start_generation); every UI
        update is marshalled back to the Tk main thread via root.after.
        """
        try:
            if not self.current_state:
                self.root.after(
                    0, self.on_generation_error, "No current state available"
                )
                return

            # Prepare base prompt
            prompt = f"Please generate a behaviour tree for the following task: {self.current_state.task_definition}"

            # Add metrics goal if provided
            if self.current_state.task_metrics_goal:
                prompt += f"\nThe task metrics to achieve are: {self.current_state.task_metrics_goal}"

            # Add feedback section if there's human feedback
            if self.current_state.human_feedback:
                previous_tree = self.current_state.behaviour_tree or "No previous tree"
                prompt += f"\n\nYour previous attempt was:\n{previous_tree}\n\nHuman Feedback on Previous Attempt: {self.current_state.human_feedback}\n\nPlease incorporate this feedback to improve the behavior tree."

            # Generate tree
            # Structured-output schema constraining the LLM's reply to a
            # single raw-XML field.
            class TreeGeneratorOutput(BaseModel):
                behaviour_tree: str = Field(
                    description="The raw behaviour tree in XML format without any quotes or markdown formatting"
                )

            tree_gen_chain = (
                self.tree_generator_prompt
                | self.tree_generator_llm.with_structured_output(TreeGeneratorOutput)
            )
            result = tree_gen_chain.invoke({"user_prompt": [("user", prompt)]})

            # Update UI in main thread
            # Providers may return either the pydantic model or a plain dict.
            if hasattr(result, "behaviour_tree"):
                self.root.after(0, self.on_tree_generated, result.behaviour_tree)
            elif isinstance(result, dict) and "behaviour_tree" in result:
                self.root.after(0, self.on_tree_generated, result["behaviour_tree"])
            else:
                self.root.after(
                    0,
                    self.on_generation_error,
                    f"Unexpected result format: {type(result)}",
                )

        except Exception as e:
            self.root.after(0, self.on_generation_error, str(e))
+
    def on_tree_generated(self, behaviour_tree: str) -> None:
        """Handle successful tree generation.

        Runs on the Tk main thread (scheduled via root.after by the worker).

        Args:
            behaviour_tree: The generated behavior tree string.
        """
        if self.current_state:
            self.current_state.behaviour_tree = behaviour_tree

        # Display tree (widget stays read-only except during the update).
        self.tree_display.config(state=tk.NORMAL)
        self.tree_display.delete(1.0, tk.END)
        self.tree_display.insert(1.0, behaviour_tree)
        self.tree_display.config(state=tk.DISABLED)

        self.update_status("Tree generated! Validating...", "orange")

        # Validate tree on a background thread to keep the UI responsive.
        threading.Thread(target=self.validate_tree, daemon=True).start()
+
    def on_generation_error(self, error_msg: str) -> None:
        """Handle tree generation errors.

        Args:
            error_msg: The error message to display.
        """
        self.update_status(f"Generation failed: {error_msg}", "red")

        # Reset workflow state so the user can retry immediately.
        self.workflow_running = False
        self.prompt_button.config(state=tk.NORMAL)
+
+ def validate_tree(self) -> None:
+ """Validate the generated behavior tree."""
+ if not self.current_state or not self.current_state.behaviour_tree:
+ self.update_status("No tree to validate", "red")
+ return
+
+ self.update_status("Validating tree...", "orange")
+
+ threading.Thread(
+ target=lambda: self._validate_tree_thread(), daemon=True
+ ).start()
+
    def _validate_tree_thread(self) -> None:
        """Run tree validation in background thread.

        Performs two independent checks — grammar conformance and primitive
        availability — and reports the combined verdict to the main thread.
        """
        try:
            if not self.current_state or not self.current_state.behaviour_tree:
                self.root.after(
                    0, self.on_validation_error, "No tree available for validation"
                )
                return

            # Grammar check: tree structure must satisfy the configured rules.
            grammar_validator = BehaviorTreeGrammarValidator(self.grammar_rules)
            passed_grammar, grammar_feedback = grammar_validator.validate_tree(
                self.current_state.behaviour_tree
            )
            # Primitive check: every leaf node must map onto the agent class.
            passed_primitive, primitive_feedback = validate_primitives(
                self.current_state.behaviour_tree, self.agent_class
            )

            # Both checks must pass; collect feedback from each failure.
            passed = passed_grammar and passed_primitive
            feedback = ""
            if not passed_grammar:
                feedback += f"Grammar validation failed: {grammar_feedback}\n"
            if not passed_primitive:
                feedback += f"Primitive validation failed: {primitive_feedback}"

            # Hand the verdict back to the Tk main thread.
            self.root.after(0, self.on_validation_complete, passed, feedback)
        except Exception as e:
            self.root.after(0, self.on_validation_error, str(e))
+
    def on_validation_complete(self, passed: bool, feedback: str) -> None:
        """Handle validation completion.

        Args:
            passed: Whether the validation passed.
            feedback: The validation feedback message.
        """
        if self.current_state:
            self.current_state.passed_validator = passed
            self.current_state.validator_feedback = feedback

        if passed:
            self.update_status("Tree validation passed!", "green")
            if self.current_state:
                # Show the desired metrics; achieved metrics are filled in
                # later by on_simulation_complete.
                self.metrics_display.config(state=tk.NORMAL)
                self.metrics_display.delete(1.0, tk.END)
                self.metrics_display.insert(
                    1.0, f"Desired Metrics:\n{self.current_state.task_metrics_goal}"
                )
                self.metrics_display.config(state=tk.DISABLED)
        else:
            self.update_status("Tree validation failed!", "red")
            self.metrics_display.config(state=tk.NORMAL)
            self.metrics_display.delete(1.0, tk.END)
            self.metrics_display.insert(
                1.0,
                f"Validation Feedback:\n{feedback}\n\nSimulation cannot run with invalid tree.",
            )
            self.metrics_display.config(state=tk.DISABLED)

        # Enable buttons based on validation result — simulation is only
        # offered for trees that passed both checks.
        self.run_sim_button.config(state=tk.NORMAL if passed else tk.DISABLED)
        self.feedback_button.config(state=tk.NORMAL)
        self.save_button.config(state=tk.NORMAL)

        # Reset workflow state
        self.workflow_running = False
        self.prompt_button.config(state=tk.NORMAL)
+
    def on_validation_error(self, error_msg: str) -> None:
        """Handle validation errors.

        Args:
            error_msg: The error message to display.
        """
        self.update_status(f"Validation failed: {error_msg}", "red")

        # Reset workflow state so the user can retry immediately.
        self.workflow_running = False
        self.prompt_button.config(state=tk.NORMAL)
+
    def run_simulation(self) -> None:
        """Run simulation of the behavior tree.

        Only allowed after validation has passed; runs the simulation on a
        daemon thread so the UI stays responsive.
        """
        if not self.current_state or not self.current_state.passed_validator:
            self.update_status("Cannot run simulation: tree validation failed", "red")
            return

        self.update_status("Running simulation...", "orange")
        # Disable the button to prevent overlapping simulation runs.
        self.run_sim_button.config(state=tk.DISABLED)

        threading.Thread(target=self.run_sim_thread, daemon=True).start()
+
    def run_sim_thread(self) -> None:
        """Run simulation in background thread.

        Raises:
            ValueError: If there is an error during simulation execution.
        """
        try:
            if not self.current_state or not self.current_state.behaviour_tree:
                raise ValueError("No valid tree available for simulation")

            # Blocking call: runs the full simulation and returns its metrics.
            metrics_result = run_robot_sim(
                self.current_state.behaviour_tree,
                environment_class=self.environment_class,
                environment_kwargs=self.environment_kwargs,
                config_kwargs=self.config_kwargs,
            )

            # Marshal the result back to the Tk main thread.
            self.root.after(0, self.on_simulation_complete, metrics_result)
        except Exception as e:
            self.root.after(0, self.on_simulation_error, str(e))
+
    def on_simulation_complete(self, metrics_result: Dict[str, Any]) -> None:
        """Handle simulation completion.

        Args:
            metrics_result: The simulation metrics and results.
        """
        if self.current_state:
            self.current_state.task_metrics_result = metrics_result

        # Update metrics display
        self.metrics_display.config(state=tk.NORMAL)
        self.metrics_display.delete(1.0, tk.END)

        if self.current_state:
            # Show desired vs. achieved metrics together for comparison.
            content = f"Desired Metrics:\n{self.current_state.task_metrics_goal}\n\n"
            content += f"Achieved Metrics:\n{metrics_result}"
            self.metrics_display.insert(1.0, content)

        self.metrics_display.config(state=tk.DISABLED)
        self.update_status("Simulation completed!", "green")

        # Re-enable the simulation button
        self.run_sim_button.config(state=tk.NORMAL)

        # Force UI update to ensure responsiveness
        self.root.update_idletasks()
        self.root.update()
+
    def on_simulation_error(self, error_msg: str) -> None:
        """Handle simulation errors.

        Args:
            error_msg: The error message to display.
        """
        self.update_status(f"Simulation failed: {error_msg}", "red")

        # Re-enable the simulation button on error so the user can retry.
        self.run_sim_button.config(state=tk.NORMAL)

        # Force UI update to ensure responsiveness
        self.root.update_idletasks()
        self.root.update()
+
    def give_feedback(self) -> None:
        """Enable feedback mode for human input.

        Reveals the feedback entry and arms the _feedback_mode flag, which
        start_generation() consumes on the next run.
        """
        if not self.current_state or not self.current_state.behaviour_tree:
            self.update_status("No tree available for feedback", "red")
            return

        self._feedback_mode = True

        # Show and setup the feedback frame
        self.feedback_frame.pack(fill=tk.X, pady=(0, 10))
        self.feedback_entry.delete(1.0, tk.END)
        self.feedback_entry.focus()

        self.update_status(
            "Provide feedback and click 'Generate Tree' to retry", "blue"
        )
+
    def save_datapoint(self) -> None:
        """Save current tree and metrics as a datapoint."""
        if not self.current_state or not self.current_state.behaviour_tree:
            self.update_status("No tree to save", "red")
            return

        try:
            # Persist the datapoint (layman/technical/spoon prompts + tree).
            save_datapoint(
                dataset_path=self.current_state.dataset_path,
                task_description=self.current_state.task_definition,
                tree_str=self.current_state.behaviour_tree,
                agent_class=self.agent_class,
            )

            self.update_dataset_info()
            self.update_status("Datapoint saved successfully!", "green")

            # Check if goal is reached
            # NOTE(review): ask_continue_generation() already calls
            # reset_workflow() when the user answers "yes", so the workflow
            # is reset twice on this path — appears harmless but redundant;
            # confirm intent.
            if self.ask_continue_generation():
                self.reset_workflow()
            else:
                self.update_status("Dataset generation completed!", "blue")
        except Exception as e:
            self.update_status(f"Failed to save datapoint: {e}", "red")
+
+ def ask_continue_generation(self) -> bool:
+ """Ask user if they want to continue generating data.
+
+ Returns:
+ True if the user wants to continue, False otherwise.
+ """
+ if not self.current_state:
+ return False
+
+ current_size = self.get_current_dataset_size()
+
+ if current_size >= self.dataset_size_goal:
+ import tkinter.messagebox as msgbox
+
+ result = msgbox.askyesno(
+ "Goal Reached",
+ f"Dataset size goal of {self.dataset_size_goal} reached! "
+ f"Current size: {current_size}. Continue generating?",
+ )
+ else:
+ import tkinter.messagebox as msgbox
+
+ result = msgbox.askyesno(
+ "Continue Generation",
+ f"Datapoint saved! ({current_size}/{self.dataset_size_goal})\nWould you like to continue generating more datapoints?",
+ )
+
+ if result:
+ self.reset_workflow()
+ else:
+ self.update_status("Generation complete", "blue")
+
+ # Refresh dataset explorer
+ self.refresh_dataset_list()
+
+ return result
+
+ def refresh_dataset_list(self) -> None:
+ """Refresh the dataset list in the dataset tab."""
+ if not hasattr(self, "datapoint_listbox"):
+ return
+
+ self.datapoint_listbox.delete(0, tk.END)
+
+ import json
+
+ try:
+ if os.path.exists(self.dataset_path):
+ with open(self.dataset_path, "r") as f:
+ dataset = json.load(f)
+ for i, datapoint in enumerate(dataset):
+ preview = datapoint.get("layman_task", "No task description")[
+ :50
+ ]
+ if len(preview) == 50:
+ preview += "..."
+ self.datapoint_listbox.insert(tk.END, f"{i+1}. {preview}")
+ except (json.JSONDecodeError, FileNotFoundError):
+ pass
+
+ self.clear_dataset_details()
+
    def on_datapoint_select(self, event: Any) -> None:
        """Handle datapoint selection in the list.

        Loads the selected entry from the dataset file and shows it in the
        detail panes.

        Args:
            event: The selection event from the listbox.
        """
        try:
            selection = self.datapoint_listbox.curselection()  # type: ignore[no-untyped-call]
            if not selection:
                return

            index = selection[0]
            import json

            with open(self.dataset_path, "r") as f:
                dataset = json.load(f)
            if index < len(dataset):
                self.display_datapoint_details(dataset[index])
        except Exception:
            # Best-effort UI handler: a missing/corrupt dataset file leaves
            # the detail panes unchanged.
            # NOTE(review): this also swallows programming errors — consider
            # narrowing to (OSError, json.JSONDecodeError).
            pass
+
    def display_datapoint_details(self, datapoint: Dict[str, Any]) -> None:
        """Display details of selected datapoint.

        Args:
            datapoint: The datapoint dictionary containing task and tree information.
        """
        # Display layman task
        self.layman_display.config(state=tk.NORMAL)
        self.layman_display.delete(1.0, tk.END)
        self.layman_display.insert(
            1.0, datapoint.get("layman_task", "No task description")
        )
        self.layman_display.config(state=tk.DISABLED)

        # Display formatted tree
        self.tree_details_display.config(state=tk.NORMAL)
        self.tree_details_display.delete(1.0, tk.END)

        tree_content = datapoint.get("tree", "No tree available")
        if tree_content:
            # Format XML for better readability; fall back to the raw string
            # when the stored tree is not parseable XML.
            try:
                import xml.dom.minidom

                dom = xml.dom.minidom.parseString(tree_content)
                formatted_tree = dom.toprettyxml(indent=" ")
                # Remove empty lines
                formatted_tree = "\n".join(
                    [line for line in formatted_tree.split("\n") if line.strip()]
                )
                self.tree_details_display.insert(1.0, formatted_tree)
            except Exception:
                self.tree_details_display.insert(1.0, tree_content)

        self.tree_details_display.config(state=tk.DISABLED)
+
+ def clear_dataset_details(self) -> None:
+ """Clear the dataset details display."""
+ if hasattr(self, "layman_display"):
+ self.layman_display.config(state=tk.NORMAL)
+ self.layman_display.delete(1.0, tk.END)
+ self.layman_display.config(state=tk.DISABLED)
+
+ if hasattr(self, "tree_details_display"):
+ self.tree_details_display.config(state=tk.NORMAL)
+ self.tree_details_display.delete(1.0, tk.END)
+ self.tree_details_display.config(state=tk.DISABLED)
+
    def run(self) -> None:
        """Start the application.

        Enters the Tk main loop; blocks until the window is closed.
        """
        self.root.mainloop()
diff --git a/data_grammar/rlhf_generation/utils/prompt_builder.py b/data_grammar/rlhf_generation/utils/prompt_builder.py
new file mode 100644
index 0000000..0ed4ef4
--- /dev/null
+++ b/data_grammar/rlhf_generation/utils/prompt_builder.py
@@ -0,0 +1,104 @@
+"""Class to build system prompt for the RLHF agent."""
+
+import os
+from typing import Dict, Type
+
+from vi import Agent
+
+
class PromptBuilder:
    """Builds system prompts for RLHF agents using XML templates and agent class behaviors."""

    def __init__(self, agent_class: Type[Agent]) -> None:
        """
        Initialize the PromptBuilder with an agent class.

        Args:
            agent_class: The agent class to extract behaviors from
        """
        self.agent_class = agent_class

    def load_xml_template(self, file_name: str) -> str:
        """
        Load an XML template from the prompt_techniques directory.

        Args:
            file_name: The name of the template file

        Returns:
            The contents of the template file

        Raises:
            FileNotFoundError: If the template does not exist (the path is
                resolved relative to the current working directory).
        """
        directory = "llm_interface/prompt_techniques"
        file_path = os.path.join(directory, file_name)
        with open(file_path, "r") as file:
            return file.read()

    def extract_agent_behaviors(self) -> Dict[str, str]:
        """
        Get all the behaviors from the agent class, extracting only Node Type and Description.

        Skips dunder methods, `update*`/`helper*` methods, and anything not
        defined directly on the agent class itself (checked via __qualname__).

        Returns:
            Dictionary of function names and their processed docstrings (Node Type and Description only)
        """
        class_obj = self.agent_class
        function_names_and_docstrings = {}

        for func_name in dir(class_obj):
            # Resolve the attribute once instead of the previous three
            # separate getattr calls per name.
            func = getattr(class_obj, func_name)
            if (
                callable(func)
                and not func_name.startswith("__")
                and not func_name.startswith("update")
                and not func_name.startswith("helper")
                and func.__qualname__.startswith(class_obj.__name__ + ".")
            ):
                if func.__doc__:
                    # Keep only the "Node Type:" / "Description:" lines of
                    # the docstring; everything else is dropped.
                    doc_lines = func.__doc__.strip().split("\n")
                    processed_doc = []

                    for line in doc_lines:
                        line = line.strip()
                        if line.startswith("Node Type:") or line.startswith(
                            "Description:"
                        ):
                            processed_doc.append(line)

                    function_names_and_docstrings[func_name] = "\n".join(processed_doc)
                else:
                    function_names_and_docstrings[func_name] = "No docstring found."

        return function_names_and_docstrings

    def build_system_prompt(self) -> str:
        """
        Build the complete system prompt by loading the XML template and replacing placeholders.

        Returns:
            The complete system prompt with all placeholders replaced
        """
        # Load and prepare system prompt
        system_prompt = self.load_xml_template("system_prompt.xml")
        bt_3_example = self.load_xml_template("bt_3.xml")
        behaviors = self.extract_agent_behaviors()

        # Format behaviors dictionary to avoid template variable conflicts.
        # All curly braces are doubled so downstream template engines treat
        # them as literals rather than placeholders.
        behaviors_str = "{{\n"  # Start with escaped curly brace
        for key, value in behaviors.items():
            safe_value = value.replace("{", "{{").replace("}", "}}")
            safe_key = key.replace("{", "{{").replace("}", "}}")
            behaviors_str += f'    "{safe_key}": "{safe_value}",\n'
        behaviors_str = (
            behaviors_str.rstrip(",\n") + "\n}}"
        )  # End with escaped curly brace

        # Replace placeholders in system prompt
        system_prompt = system_prompt.replace("{bt_3.xml}", bt_3_example)
        system_prompt = system_prompt.replace("{BEHAVIORS}", behaviors_str)

        return system_prompt
diff --git a/data_grammar/rlhf_generation/utils/run_robot_sim.py b/data_grammar/rlhf_generation/utils/run_robot_sim.py
new file mode 100644
index 0000000..a7ce216
--- /dev/null
+++ b/data_grammar/rlhf_generation/utils/run_robot_sim.py
@@ -0,0 +1,99 @@
+"""Run robot simulation with behavior tree and return metrics."""
+
+import os
+from typing import Any, Dict, Optional, Type
+
+from vi import Config, Window
+
+from agent_control import RobotEnvironment
+from tree_parser import save_behavior_tree_xml
+
+
def run_robot_sim(
    bt_str: str,
    environment_class: Optional[Type[Any]] = None,
    environment_kwargs: Optional[Dict[str, Any]] = None,
    config_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Run a robot environment simulation with behavior trees.

    Args:
        bt_str: The behavior tree XML string
        environment_class: The environment class to use (defaults to RobotEnvironment)
        environment_kwargs: Dict of keyword arguments for the environment constructor
            Example: {'n_agents': 10, 'n_parts': 15, 'task': 'collect', 'headless': False}
        config_kwargs: Dict of keyword arguments for the Config constructor
            Example: {'radius': 25, 'duration': 1000, 'window_size': 500, 'movement_speed': 1.0}

    Returns:
        Dict containing simulation metrics
    """
    # Use defaults if not provided
    environment_class = environment_class or RobotEnvironment
    environment_kwargs = environment_kwargs or {}
    config_kwargs = config_kwargs or {}

    # Set default config values if not provided
    default_config = {
        "radius": 25,  # Agent detection radius
        "visualise_chunks": True,  # Show spatial partitioning
        "window_size": 500,  # Window size (will create square window)
        "movement_speed": 1.0,  # Agent movement speed
        "duration": 1000,  # Simulation duration in steps
    }

    # Merge default config with provided config
    final_config = {**default_config, **config_kwargs}

    # Handle window_size parameter - convert to Window object
    if "window_size" in final_config:
        window_size = final_config.pop("window_size")
        final_config["window"] = Window.square(window_size)
    elif "window" not in final_config:
        # If neither window_size nor window is provided, use default
        final_config["window"] = Window.square(500)

    # Configuration for the simulation
    config = Config(**final_config)

    # Create a temp path for the behavior tree
    bt_path = "./temp_bt.xml"
    save_behavior_tree_xml(bt_str, bt_path)

    # Set default environment values if not provided
    default_environment = {
        "config": config,
        "bt_path": bt_path,
        "n_agents": 10,  # Number of robot agents
        "n_parts": 10,  # Number of parts in environment
        "task": "collect",  # Task description
        "headless": False,  # Show GUI (set to True for no GUI)
    }

    # Merge default environment with provided environment kwargs
    # Note: provided kwargs override defaults, except config and bt_path
    # which are always forced to the values constructed above.
    final_environment = {**default_environment, **environment_kwargs}
    final_environment["config"] = config  # Always use the constructed config
    final_environment["bt_path"] = bt_path  # Always use the temp bt_path

    # Create the environment
    environment = environment_class(**final_environment)
    environment.setup()

    try:
        # Run the simulation and collect metrics
        results = environment.run()
    finally:
        # Ensure pygame is properly cleaned up
        try:
            import pygame

            pygame.display.quit()
            pygame.quit()
        except Exception:
            pass  # Ignore errors if pygame is not available or already cleaned up

        # FIX: delete the temp behavior tree even when the simulation raises;
        # previously the file leaked on any error path.
        if os.path.exists(bt_path):
            os.remove(bt_path)

    return results
diff --git a/data_grammar/rlhf_generation/utils/save_data_point.py b/data_grammar/rlhf_generation/utils/save_data_point.py
new file mode 100644
index 0000000..86c7592
--- /dev/null
+++ b/data_grammar/rlhf_generation/utils/save_data_point.py
@@ -0,0 +1,87 @@
+"""Save a datapoint in the given path after producing other prompt styles."""
+
+import json
+import os
+from typing import Any, Type
+
+from data_grammar.grammar_gen.node_translations import node_connectors
+from data_grammar.grammar_gen.tree_to_prompt import (
+ generate_spoon_prompt_from_string,
+ generate_technical_prompt_from_string,
+)
+from tree_parser.agent_doc_parser import AgentDocstringParser
+
+
def save_datapoint(
    dataset_path: str, task_description: str, tree_str: str, agent_class: Type[Any]
) -> None:
    """
    Save a single datapoint to the dataset file after generating other prompt styles.

    Args:
        dataset_path: Path to the JSON dataset file (should end with .json)
        task_description: The layman task description
        tree_str: The behavior tree XML string
        agent_class: The agent class to extract translations from
    """
    # Ensure the dataset path has .json extension
    if not dataset_path.endswith(".json"):
        dataset_path = dataset_path + ".json"

    # Retrieve the translations
    doc_parser = AgentDocstringParser(agent_class=agent_class)
    extracted_info_dict = doc_parser.extract_docstring_config()
    spoon_translations = extracted_info_dict["spoon_node_translations"]
    technical_translations = extracted_info_dict["node_translations"]

    # Produce the prompts; generation failures are recorded in the datapoint
    # rather than aborting the save.
    layman_prompt = task_description
    try:
        spoon_prompt = generate_spoon_prompt_from_string(
            tree_string=tree_str,
            spoon_translations=spoon_translations,
            connectors=node_connectors,
        )
    except Exception as e:
        spoon_prompt = f"Error generating spoon task: {e}"

    try:
        tech_prompt = generate_technical_prompt_from_string(
            tree_string=tree_str,
            translations=technical_translations,
            connectors=node_connectors,
        )
    except Exception as e:
        tech_prompt = f"Error generating technical task: {e}"

    # Create the datapoint in the same format as the other generators
    datapoint = {
        "layman_task": layman_prompt,
        "technical_task": tech_prompt,
        "spoon_task": spoon_prompt,
        "tree": tree_str,
    }

    # Load existing dataset or create new one
    dataset = []
    if os.path.exists(dataset_path):
        try:
            with open(dataset_path, "r") as json_file:
                dataset = json.load(json_file)
        except (json.JSONDecodeError, FileNotFoundError):
            # If file exists but is corrupted or empty, start fresh
            dataset = []

    # Append the new datapoint
    dataset.append(datapoint)

    # Create directory if it doesn't exist.
    # FIX: guard against a bare filename — os.path.dirname returns "" then,
    # and os.makedirs("") raises FileNotFoundError.
    directory = os.path.dirname(dataset_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Save the updated dataset
    with open(dataset_path, "w") as json_file:
        json.dump(dataset, json_file, indent=4)

    print(
        f"Datapoint saved to {dataset_path}. Dataset now contains {len(dataset)} samples."
    )
diff --git a/examples/03_data_gen_examples/rlhf_dataset_gen.py b/examples/03_data_gen_examples/rlhf_dataset_gen.py
new file mode 100644
index 0000000..a8fcb19
--- /dev/null
+++ b/examples/03_data_gen_examples/rlhf_dataset_gen.py
@@ -0,0 +1,61 @@
+"""
+Example showing how to use the UnifiedRLHFUI class to generate
+behavior trees and datasets through a collaboration between LLM and human.
+"""
+from agent_control import RobotAgent, RobotEnvironment
+from data_grammar.rlhf_generation.rlhf_unified import UnifiedRLHFUI
+
+# Step 1: Define grammar rules for behavior tree validation
+grammar_rules = {
+ "B": [["b", ["SEL"]], ["b", ["SEQ"]]],
+ "SEL": [["sel", ["SEQn", "As"]], ["sel", ["SEQn"]]],
+ "SEQn": [["SEQ", "SEQn"], ["SEQ"]],
+ "SEQ": [["seq", ["Pn", "A"]], ["seq", ["As", "Pn", "A"]]],
+ "b": ["BehaviorTree", ["children_nodes"]],
+ "sel": ["Selector", ["children_nodes"]],
+ "seq": ["Sequence", ["children_nodes"]],
+ "A": [["aa", "sa"], ["aa"], ["sa"]],
+ "As": [["aa"], ["sa"]],
+ "aa": ["ActuatorAction"],
+ "sa": ["StateAction"],
+ "Pn": [["p", "Pn"], ["p"], []],
+ "p": ["Condition"],
+}
+
+# Step 2: Configure simulation parameters
+config_kwargs = {
+ "radius": 40,
+ "duration": 500,
+ "movement_speed": 1.0,
+ "window_size": 800,
+ "visualise_chunks": True,
+}
+
+# Step 3: Configure environment parameters
+environment_kwargs = {
+ "n_agents": 10,
+ "n_parts": 10,
+ "task": "custom_task",
+ "headless": False,
+}
+
+# Step 4: Initialize the RLHF dataset generation UI
+# This creates a unified interface for interactive behavior tree generation
+ui = UnifiedRLHFUI(
+ dataset_path="./examples/03_data_gen_examples/output/dataset_path.json", # Output dataset file
+ dataset_size_goal=10, # Target number of datapoints
+ agent_class=RobotAgent, # Agent class for validation
+ grammar_rules=grammar_rules, # Grammar rules defined above
+ environment_class=RobotEnvironment, # Environment class for simulation
+ environment_kwargs=environment_kwargs, # Environment parameters
+ config_kwargs=config_kwargs, # Simulation configuration
+)
+
+# Step 5: Start the interactive RLHF dataset generation process
+# - Define tasks in natural language
+# - Generate behavior trees using LLM
+# - Validate trees against grammar and primitives
+# - Run simulations to test behavior
+# - Provide human feedback for iterative improvement
+# - Save successful datapoints to the dataset
+ui.run()
\ No newline at end of file
diff --git a/llm_interface/layer_LLM.py b/llm_interface/layer_LLM.py
index b5842d0..f16e277 100644
--- a/llm_interface/layer_LLM.py
+++ b/llm_interface/layer_LLM.py
@@ -21,7 +21,7 @@
except ImportError:
OLLAMA_AVAILABLE = False
-from dotenv import load_dotenv # type: ignore
+from dotenv import load_dotenv
from lancedb import table
from lancedb.pydantic import LanceModel
from openai import OpenAI
@@ -406,7 +406,7 @@ def _auto_import_gguf(self, gguf_path: str) -> str:
with open(modelfile_path, "r") as f:
modelfile_content = f.read()
- ollama.create(model=model_name, modelfile=modelfile_content)
+ ollama.create(model=model_name, modelfile=modelfile_content) # type: ignore[call-overload]
print(f"Successfully imported model as: {model_name}")
# Return with :latest suffix since that's how Ollama stores it
return f"{model_name}:latest"
@@ -414,7 +414,7 @@ def _auto_import_gguf(self, gguf_path: str) -> str:
# If API also fails, try one more approach - stream the creation
try:
print("Trying streaming create...")
- for response in ollama.create(
+ for response in ollama.create( # type: ignore[call-overload]
model=model_name, modelfile=modelfile_content, stream=True
):
if "status" in response: