Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion agent/multi_agent_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from PIL import Image

from browser_env import Action, Trajectory
from browser_env import Action, Trajectory, ActionTypes, create_stop_action
from browser_env.actions import is_equivalent
from browser_env.helper_functions import get_action_description
from llms import lm_config

Expand Down Expand Up @@ -81,6 +82,10 @@ def __init__(self, lm_config: lm_config.LMConfig,
self.actions: List[Action] = []
self.reflections: List[Dict[str, Any]] = []

# Early stop thresholds (can be provided via memory_config or use defaults)
self.parsing_failure_th = memory_config.get("parsing_failure_th", 3)
self.repeating_action_th = memory_config.get("repeating_action_th", 5)

# Meta data for action history tracking (required by DirectPromptConstructor)
# Initialize with "None" as the first action, matching run.py implementation
self.meta_data: Dict[str, Any] = {"action_history": ["None"]}
Expand Down Expand Up @@ -214,6 +219,19 @@ def __init__(self, url: str = ""):
# Main execution loop
while True:
try:
# Early-stop check (mirrors run.py behavior)
early_flag, early_info = self._early_stop(self.max_steps, {
"parsing_failure": self.parsing_failure_th,
"repeating_action": self.repeating_action_th,
})
if early_flag:
# Append a stop action to the trajectory and end execution
stop_action = create_stop_action(f"Early stop: {early_info}")
self.actions.append(stop_action)
self.trajectory.append(stop_action)
# Log and break
self.log_agent_response("coordinator", len(self.actions), {"early_stop": early_info})
break
# Check if we should continue
context_summary = self._get_current_context_summary()
continuation_decision = self.workflow_manager.should_continue_execution(context_summary)
Expand Down Expand Up @@ -289,6 +307,59 @@ def __init__(self, url: str = ""):
"final_context": final_context_summary,
}

def _early_stop(self, max_steps: int, thresholds: Dict[str, int]) -> tuple[bool, str]:
    """Check whether execution should stop early, mirroring run.py's early_stop.

    The trajectory alternates [state, action, state, action, ...], so the
    produced actions live at the odd indices (``trajectory[1::2]``).

    Args:
        max_steps: Maximum number of action steps before forcing a stop.
        thresholds: Mapping with ``"parsing_failure"`` (consecutive
            unparsable actions tolerated) and ``"repeating_action"``
            (equivalent repeated actions tolerated).

    Returns:
        Tuple ``(should_stop, reason)``; ``reason`` is ``""`` when no
        early-stop condition fired.
    """
    # Each step appends (action, state) after the initial state, so the
    # number of completed steps is half the remaining length.
    num_steps = (len(self.trajectory) - 1) / 2
    if num_steps >= max_steps:
        return True, f"Reach max steps {max_steps}"

    # Stop when the last k produced actions all failed to parse.
    k = thresholds.get("parsing_failure", 3)
    if k > 0:
        last_k_actions = self.trajectory[1::2][-k:]
        if len(last_k_actions) >= k and all(
            isinstance(action, dict) and action.get("action_type") == ActionTypes.NONE
            for action in last_k_actions
        ):
            return True, f"Failed to parse actions for {k} times"

    # Stop when the agent keeps issuing the same action.
    k = thresholds.get("repeating_action", 5)
    action_seq = self.trajectory[1::2]
    # Guard k <= 0 explicitly: action_seq[-0:] would slice the WHOLE list and
    # could trigger a spurious "Same action for 0 times" stop.
    if k <= 0 or not action_seq:
        return False, ""

    last_action = action_seq[-1]
    last_action_type = (
        last_action.get("action_type") if isinstance(last_action, dict) else None
    )

    if last_action_type != ActionTypes.TYPE:
        # Non-typing actions: require k *consecutive* equivalent actions.
        if len(action_seq) >= k and all(
            is_equivalent(action, last_action) for action in action_seq[-k:]
        ):
            return True, f"Same action for {k} times"
    else:
        # Typing actions: count equivalents across the full sequence, since
        # repeated typing may be interleaved with other actions.
        try:
            count_same = sum(
                1 for action in action_seq if is_equivalent(action, last_action)
            )
            if count_same >= k:
                return True, f"Same typing action for {k} times"
        except Exception:
            # is_equivalent may raise on malformed actions; treat as no match.
            pass

    return False, ""

def _register_agents(self) -> None:
"""Register all agents with the communication hub."""
self.communication_hub.register_agent(
Expand Down Expand Up @@ -476,6 +547,29 @@ def _execute_coordination_cycle(self, images: Optional[List[Image.Image]] = None
self.actions.append(executed_action)

# Execute action in browser environment if available
# If the executed action is a STOP action, mirror run.py behavior:
# append the action to the trajectory and terminate execution cycle
if executed_action.get("action_type") == ActionTypes.STOP:
print(f"🔍 Executing action in browser: {executed_action.get('action_type', 'UNKNOWN')}")
# update action history and trajectory so evaluators see Action as last element
action_str = f"STOP: {executed_action.get('answer', '')}"
try:
self.meta_data["action_history"].append(action_str)
except Exception:
self.meta_data["action_history"] = ["None", action_str]

# append the stop action to trajectory and return termination signal
self.trajectory.append(executed_action)
return {
"should_terminate": True,
"step_number": step_number,
"context_result": context_result,
"planning_result": planning_result,
"execution_result": execution_result,
"reflection_result": {},
"new_observation": None,
}

if self.browser_env is not None:
try:
print(f"🔍 Executing action in browser: {executed_action.get('action_type', 'UNKNOWN')}")
Expand Down
2 changes: 1 addition & 1 deletion evaluation_harness/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_last_action(trajectory: Trajectory) -> Action:
raise ValueError(
"The last element of trajectory should be an action, add a fake stop action if needed"
)

print(f"Last action: {last_action}")
return last_action # type: ignore[return-value]

@staticmethod
Expand Down
104 changes: 52 additions & 52 deletions evaluation_harness/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,62 @@
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from transformers import (
Blip2ForConditionalGeneration,
Blip2Processor,
)

from openai import OpenAI
import base64
import io
import os
from browser_env.utils import pil_to_b64

def get_captioning_fn(
device, dtype, model_name: str = "Salesforce/blip2-flan-t5-xl"
) -> callable:
if "blip2" in model_name:
captioning_processor = Blip2Processor.from_pretrained(model_name)
captioning_model = Blip2ForConditionalGeneration.from_pretrained(
model_name, torch_dtype=dtype
)
else:
raise NotImplementedError(
"Only BLIP-2 models are currently supported"
)
captioning_model.to(device)

def caption_images(
images: List[Image.Image],
prompt: List[str] = None,
max_new_tokens: int = 32,
) -> List[str]:
if prompt is None:
# Perform VQA
inputs = captioning_processor(
images=images, return_tensors="pt"
).to(device, dtype)
generated_ids = captioning_model.generate(
**inputs, max_new_tokens=max_new_tokens
)
captions = captioning_processor.batch_decode(
generated_ids, skip_special_tokens=True
)
else:
# Regular captioning. Prompt is a list of strings, one for each image
assert len(images) == len(
prompt
), "Number of images and prompts must match, got {} and {}".format(
len(images), len(prompt)
)
inputs = captioning_processor(
images=images, text=prompt, return_tensors="pt"
).to(device, dtype)
generated_ids = captioning_model.generate(
**inputs, max_new_tokens=max_new_tokens
)
captions = captioning_processor.batch_decode(
generated_ids, skip_special_tokens=True
)
def get_captioning_fn(
    device, dtype, model_name: str = "qwen3-vl-plus"
) -> callable:
    """Build a captioning/VQA function backed by an OpenAI-compatible API.

    ``device`` and ``dtype`` are accepted only for interface compatibility
    with the previous local BLIP-2 implementation; the API client ignores
    them.

    Args:
        device: Unused (kept for signature compatibility).
        dtype: Unused (kept for signature compatibility).
        model_name: Vision-capable chat model to request from the API.

    Returns:
        ``caption_images(images, prompt=None, max_new_tokens=300)`` which
        returns one caption string per input image ("" on API failure).
    """
    # Credentials/endpoint come from the environment; base_url allows
    # pointing at any OpenAI-compatible gateway.
    client = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
        base_url=os.environ.get("OPENAI_BASE_URL"),
    )

    def caption_images(
        images: List[Image.Image],
        prompt: List[str] = None,
        max_new_tokens: int = 300,  # larger default than the local model's 32
    ) -> List[str]:
        # No prompts supplied -> generic captioning; otherwise one per image.
        if prompt is None:
            prompts = ["Describe this image in detail."] * len(images)
        else:
            prompts = prompt

        assert len(images) == len(prompts), "Number of images and prompts must match"

        captions: List[str] = []
        for img, p in zip(images, prompts):
            try:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": p},
                                {
                                    "type": "image_url",
                                    "image_url": {"url": pil_to_b64(img)},
                                },
                            ],
                        }
                    ],
                    max_tokens=max_new_tokens,
                )
                captions.append(response.choices[0].message.content)
            except Exception as e:
                # Best-effort: log and keep output aligned with the input
                # images by appending an empty caption.
                print(f"Error calling API: {e}")
                captions.append("")

        return captions

    return caption_images


def get_image_ssim(imageA, imageB):
Expand All @@ -81,4 +81,4 @@ def get_image_ssim(imageA, imageB):

# Compute the Structural Similarity Index (SSIM) between the two images
score, _ = ssim(grayA, grayB, full=True)
return score
return score
Loading