diff --git a/examples/experimental/swe-agent/README.md b/examples/experimental/swe-agent/README.md
index d71ad3ecd..53b780848 100644
--- a/examples/experimental/swe-agent/README.md
+++ b/examples/experimental/swe-agent/README.md
@@ -2,119 +2,69 @@
 
 ## Introduction
 
-This is an example for SWE-agent training. This example uses NVIDIA's Nemo-Gym as the Gym environment implement, SWE-Gym as the training data, and SWE-bench as the evaluation.
+This example demonstrates SWE-agent training. It uses NVIDIA's `mini-swe-agent` as the agent implementation, SWE-Gym as the training data, and SWE-bench as the evaluation harness.
 
-This implementation of this example is partially in submodules below:
-- Nemo-Gym: https://github.com/yueming-yuan/Gym/tree/miles-swe-agent
+Part of the implementation lives in the submodule below:
 - mini-swe-agent: https://github.com/yueming-yuan/nv-mini-swe-agent/tree/miles-swe-agent
 
-
 ## Prepare environment
+
 ### Update submodules
 ```bash
-git submodule update --init --recursive .
+git submodule update --init --recursive
 ```
 
-### Docker settings
-```bash
-# 1. create a docker network
-docker network create swe-net
-# 2. create environment docker
-docker run -itd \
-    --name swe_env \
-    --shm-size 16g \
-    -v /var/run/docker.sock:/var/run/docker.sock \
-    -v /mnt/data:/data \
-    -v /home/sglang-rl/:/workspace \
-    --ipc=host \
-    --ulimit nofile=65536:65536 \
-    --ulimit memlock=-1 \
-    --ulimit stack=67108864 \
-    --network swe-net \
-    ubuntu:latest \
-    /bin/bash
+### Docker settings
+To run SWE-agent, the Miles container must be able to launch temporary Docker containers for each task environment. This requires mounting the Docker socket.
 
-# 3. create miles docker
+```bash
 docker run -itd \
     --shm-size 32g \
     --gpus all \
     -v /mnt/data/cache/huggingface:/root/.cache/huggingface \
     -v /mnt/data:/data \
-    -v /home/sglang-rl/:/workspace \
+    -v /var/run/docker.sock:/var/run/docker.sock \
     --ipc=host \
     --ulimit nofile=65536:65536 \
    --ulimit memlock=-1 \
     --ulimit stack=67108864 \
     --privileged \
-    --network swe-net \
+    --network host \
     --name miles_ \
     radixark/miles:latest \
     /bin/zsh
-
-# 4. install utils in environment docker
-docker exec -it swe_env /bin/bash
-apt update && apt install -y zsh curl git python3 python3-pip docker.io
 ```
-note: `-v /var/run/docker.sock:/var/run/docker.sock` is required for Docker-in-Docker SWE environment execution; use `--network swe-net` to enable communication between training & environment.
+Note: `-v /var/run/docker.sock:/var/run/docker.sock` is required for Docker-in-Docker SWE environment execution.
 
 ### Installation
-In **environment docker**, install Gym
-```bash
-git clone https://github.com/yueming-yuan/Gym
-cd Gym
-
-curl -LsSf https://astral.sh/uv/install.sh | sh
-source $HOME/.local/bin/env
-uv venv --python 3.12 && source .venv/bin/activate
-uv sync --extra dev --group docs
+Inside the **Miles docker**, install the Docker CLI and the SWE-Gym harness:
 
-# configure env.yaml
-echo "policy_base_url: https://api.openai.com/v1
-policy_api_key: your-openai-api-key
-policy_model_name: gpt-4.1-2025-04-14
-default_host: 0.0.0.0" > env.yaml
-```
-note: set host IP to `0.0.0.0` to enable communications between dockers.
-
-then set up for SWE-agent server:
 ```bash
-cd responses_api_agents/mini_swe_agent
-uv pip install -r requirements.txt
-```
-Now you should be able to run the SWE-agent server.
-
-For **miles docker** setup, please follow the standard setup process.
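+# The CLI installed below talks to the host Docker daemon through the mounted
+# /var/run/docker.sock; `docker info` is a quick way to verify the connection.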
+# Install Docker CLI
+apt update && apt install -y docker.io
 
-## Preparing data
-In **miles docker**, download **SWE-Gym** data from huggingface and convert it to Miles' prompt data format with this script.
-```
-cd miles/examples/swe-agent
-python download_and_process_data.py --input SWE-Gym/SWE-Gym --output /root/swe_train.jsonl
+# Install SWE-Gym harness
+pip install "swegym @ git+https://github.com/sdevare-nv/nv-SWE-Bench-Package.git@31e1cb8f0241da1707d00faa633c3d6ce1a8ba3b"
 ```
 
-## Running train
-1. In environment docker, launch the agent server
+## Preparing data
+Download **SWE-Gym** data from Hugging Face and convert it to Miles' prompt data format:
 ```bash
-cd Gym
-source .venv/bin/activate
-cd responses_api_agents/mini_swe_agent
-./start_server.sh
+cd examples/experimental/swe-agent
+python3 download_and_process_data.py --input SWE-Gym/SWE-Gym --output /root/swe_train.jsonl
 ```
-
-2. In miles docker, (1) export `SWE_AGENT_GYM_URL` to be the port of the second server you started in Gym in environment docker, whose `server_type` is `responses_api_agents`. `swe_env` is the environment docker's name; replace it if you changed the name. (minor TODO: modify the port selections to avoid setting this every time.) (2) launch the training.
+## Running training
+Launch the training directly from the Miles root directory:
 ```bash
-export SWE_AGENT_GYM_URL="http://swe_env:"
-bash examples/swe-agent/run-qwen3-4b-instruct.sh
+bash examples/experimental/swe-agent/run-qwen3-4b-instruct.sh
 ```
-
 ## Troubleshooting
-1. The first time of every SWE environment can be slow, and may need to wait before generation, because each SWE-Gym task has a specific docker, and `docker pull` takes time.
-2. Sometimes the environment may also be slow at evaluation. The timeout of evaluation is 10 minutes by default. If the server is stuck at `[EVAL] Running eval`, you may need to wait for it.
+1. **Slow Startup:** The first time a SWE environment is created, it may be slow because each SWE-Gym task requires a specific Docker image, and `docker pull` takes time (see the pre-pull sketch below).
+2. **Evaluation Timeout:** Sometimes the environment may be slow during evaluation. The default timeout is 10 minutes. If logs show `[EVAL] Running eval`, please wait for completion.
+3. **Ray Stability:** The rollout process uses background threads to ensure the Ray cluster remains responsive during long-running agent tasks.
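+
+To reduce the cold-start delay from item 1, the task images can be pre-pulled before training. The sketch below is illustrative, not part of the pipeline: it assumes `mini-swe-agent/src` from the submodule is on `PYTHONPATH` and that each row of the processed JSONL keeps the instance metadata fields used by `generate_with_swe_agent.py` (adjust the field lookup to your data layout).
+```python
+# pre_pull_images.py -- illustrative sketch; field names are assumptions, adjust as needed
+import json
+import subprocess
+
+# Same image-name helper that generate_with_swe_agent.py uses
+from minisweagent.run.extra.swegym_runner import get_swegym_docker_image_name
+
+with open("/root/swe_train.jsonl") as f:
+    for line in f:
+        record = json.loads(line)
+        # Assumed layout: instance fields stored under a "metadata" key
+        metadata = record.get("metadata", record)
+        image = get_swegym_docker_image_name(metadata, metadata.get("subset", "gym"))
+        subprocess.run(["docker", "pull", image], check=False)
+```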
 ## Metrics
 
 ```
@@ -124,7 +74,6 @@ agent/total_time_mean/max/min - Total time statistics
 agent/model_query_time_sum_mean - Avg total model time per rollout
 agent/env_execution_time_sum_mean - Avg total env time per rollout
 agent/eval_time_mean - Avg evaluation time
-agent/overhead_time_mean - Avg overhead time
 agent/time_per_turn - Avg time per turn
 agent/model_query_time_avg - Avg model query time per turn
 agent/env_execution_time_avg - Avg env execution time per turn
diff --git a/examples/experimental/swe-agent/generate_with_swe_agent.py b/examples/experimental/swe-agent/generate_with_swe_agent.py
index b0dbbd612..f2006250a 100644
--- a/examples/experimental/swe-agent/generate_with_swe_agent.py
+++ b/examples/experimental/swe-agent/generate_with_swe_agent.py
@@ -1,14 +1,22 @@
+import asyncio
 import logging
-import os
+import time
+import uuid
 from argparse import Namespace
 from collections.abc import Callable
+from pathlib import Path
 from typing import Any
 
+from minisweagent.agents.default import DefaultAgent
+
+from minisweagent.environments import DockerEnvironment
+from minisweagent.models import get_model
+from minisweagent.run.extra.swegym_runner import get_swegym_docker_image_name, run_eval
+
 from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput
 from miles.rollout.filter_hub.base_types import DynamicFilterOutput
 from miles.rollout.sglang_rollout import GenerateState, eval_rollout
 from miles.utils.async_utils import run
-from miles.utils.http_utils import post
 from miles.utils.types import Sample
 
 logger = logging.getLogger(__name__)
@@ -56,63 +64,129 @@ def build_tokens_and_mask_from_messages(
 
     return all_tokens, loss_mask, response_text, response_length
 
 
+def run_agent_sync_logic(model, env, problem_statement, sampling_params, metadata, instance_dir, run_id):
+    """
+    Synchronous wrapper to run the agent and evaluation.
+    This is offloaded to a thread to prevent blocking the Ray actor.
+    """
+    agent = DefaultAgent(
+        model=model,
+        env=env,
+        responses_create_params={"input": []},
+        sampling_params=sampling_params,
+        step_limit=250,
+        collapse_limit=3,
+    )
+
+    # Execute the agent lifecycle
+    exit_status, result_patch, agent_metrics = agent.run(problem_statement)
+
+    # Run evaluation
+    eval_start = time.time()
+    eval_report_full = run_eval(
+        instance=metadata,
+        env=env,
+        model_patch=result_patch,
+        instance_dir=instance_dir,
+        run_id=run_id,
+    )
+    eval_time = time.time() - eval_start
+
+    # Metrics calculation
+    agent_metrics["eval_time"] = eval_time
+    total_time = agent_metrics.get("agent_run_time", 0) + eval_time
+    agent_metrics["total_time"] = total_time
+    agent_metrics["model_time_ratio"] = agent_metrics.get("model_query_time_sum", 0) / max(total_time, 1e-6)
+    agent_metrics["env_time_ratio"] = agent_metrics.get("env_execution_time_sum", 0) / max(total_time, 1e-6)
+    agent_metrics["eval_time_ratio"] = eval_time / max(total_time, 1e-6)
+
+    return exit_status, agent.messages, eval_report_full, agent_metrics
+
+
 async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample:
     """
     Custom generation function for SWE-Agent integration.
 
-    Orchestrates the interaction with the external Gym environment:
-    1. Sends prompt/metadata to Gym.
-    2. Receives execution trace (messages) and rewards.
+    Orchestrates the in-process mini-swe-agent rollout:
+    1. Directly initializes mini-swe-agent components.
+    2. Runs agent logic in a background thread to maintain Ray cluster stability.
     3. Formats data for Miles training format.
 
     Note: Performs in-place modification of `sample` for memory efficiency.
""" - # Prepare request for Gym /run endpoint - request = { - "responses_create_params": { - "input": [], - }, - "sampling_params": sampling_params, - **sample.metadata, - "sglang_url": f"http://{args.sglang_router_ip}:{args.sglang_router_port}/v1", - } - - gym_url = os.getenv("SWE_AGENT_GYM_URL", "http://localhost:11000") - response = await post(f"{gym_url}/run", request) - - exit_status = response.get("info", {}).get("exit_status", "") - logger.debug(f"exit_status: {exit_status}, reward: {response.get('reward', 0.0)}") - - messages = response.get("messages", []) - - if len(messages) >= 2: - sample.prompt = messages[:2] - - state = GenerateState(args) - tokens, loss_mask, response_text, response_length = build_tokens_and_mask_from_messages( - messages=messages, - tokenizer=state.tokenizer, - ) + instance_id = sample.metadata.get("instance_id") + subset = sample.metadata.get("subset", "gym") + problem_statement = sample.metadata.get("problem_statement") + + # Model configuration + model_name = f"sglang/{Path(args.hf_checkpoint).name}" + sglang_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/v1" + + model_config = {"model_name": model_name, "model_kwargs": {"base_url": sglang_url, "api_key": "dummy"}} + model = get_model(model_name, config=model_config) + + # Environment configuration + image_name = get_swegym_docker_image_name(sample.metadata, subset) + output_dir = Path("results") / subset / model_name + instance_dir = output_dir / instance_id + instance_dir.mkdir(parents=True, exist_ok=True) + run_id = f"{int(time.time())}_{str(uuid.uuid4())[:8]}" + + env = None + try: + # Initialize the Docker environment + env = DockerEnvironment( + image=image_name, + instance_id=instance_id, + step_timeout=600, + eval_timeout=600, + ) + + # Off-Load to Thread + exit_status, messages, eval_report_full, agent_metrics = await asyncio.to_thread( + run_agent_sync_logic, model, env, problem_statement, sampling_params, sample.metadata, instance_dir, run_id + ) + + # Extract reward from evaluation report + report_data = eval_report_full.get("eval_report", {}).get(instance_id, {}) + resolved = report_data.get("resolved", False) + reward = 1.0 if resolved else 0.0 + + if len(messages) >= 2: + sample.prompt = messages[:2] + + state = GenerateState(args) + tokens, loss_mask, response_text, response_length = build_tokens_and_mask_from_messages( + messages=messages, + tokenizer=state.tokenizer, + ) + + sample.rollout_log_probs = None # TODO + sample.tokens = tokens + sample.loss_mask = loss_mask + sample.response = response_text + sample.response_length = response_length + sample.reward = reward + sample.metadata["reward"] = reward + sample.metadata["eval_report"] = eval_report_full + sample.metadata["messages"] = messages + sample.metadata["agent_metrics"] = agent_metrics + + if exit_status == "Submitted": + sample.status = Sample.Status.COMPLETED + elif exit_status in ("RolloutTruncated", "LimitsExceeded", "CollapseContinued"): + sample.status = Sample.Status.TRUNCATED + else: + sample.status = Sample.Status.ABORTED + sample.reward = 0.0 - sample.rollout_log_probs = None # TODO - sample.tokens = tokens - sample.loss_mask = loss_mask - sample.response = response_text - sample.response_length = response_length - sample.metadata["reward"] = response.get("reward", 0.0) - sample.metadata["eval_report"] = response.get("metadata", {}) - sample.metadata["messages"] = messages - - agent_metrics = response.get("info", {}).get("agent_metrics", {}) - sample.metadata["agent_metrics"] = agent_metrics - - if 
exit_status == "Submitted": - sample.status = Sample.Status.COMPLETED - elif exit_status in ("RolloutTruncated", "LimitsExceeded", "CollapseContinued"): - sample.status = Sample.Status.TRUNCATED - else: + except Exception as e: + logger.error(f"Error processing instance {instance_id}: {e}", exc_info=True) sample.status = Sample.Status.ABORTED sample.reward = 0.0 + finally: + if env: + env.cleanup() return sample diff --git a/examples/experimental/swe-agent/run-qwen3-4b-instruct.sh b/examples/experimental/swe-agent/run-qwen3-4b-instruct.sh index d9c9dd953..55583df97 100755 --- a/examples/experimental/swe-agent/run-qwen3-4b-instruct.sh +++ b/examples/experimental/swe-agent/run-qwen3-4b-instruct.sh @@ -25,9 +25,7 @@ echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -export SWE_AGENT_GYM_URL="${SWE_AGENT_GYM_URL:-http://swe_env:11000}" - -source "${SCRIPT_DIR}/../../scripts/models/qwen3-4B-Instruct-2507.sh" +source "${SCRIPT_DIR}/../../../scripts/models/qwen3-4B-Instruct-2507.sh" CKPT_ARGS=( --hf-checkpoint /root/qwen3-4B-Instruct-2507 @@ -135,15 +133,13 @@ ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-s RUNTIME_ENV_JSON="{ \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}:/root/miles\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"SWE_AGENT_GYM_URL\": \"${SWE_AGENT_GYM_URL}\" + \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}:/root/miles:${SCRIPT_DIR}/mini-swe-agent/src\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" } }" # \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", echo "Launching training..." -echo " SWE Agent URL: ${SWE_AGENT_GYM_URL}" ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \