diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index f00faa5a6..1e26168d6 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -25,6 +25,46 @@ concurrency:
jobs:
+ fast:
+ if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request)
+ runs-on: self-hosted
+ container:
+ image: radixark/miles:latest
+ options: >
+ --gpus all
+ --ipc=host
+ --shm-size=16g
+ --ulimit memlock=-1
+ --ulimit stack=67108864
+ --memory=0
+ --memory-swap=0
+ -v /mnt/nvme0n1/miles_ci:/data/miles_ci
+ -v /mnt/nvme0n1/miles_ci/models:/root/models
+ -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets
+ strategy:
+ fail-fast: false
+ matrix:
+ info: [{"num_gpus": 0, "test_file": "fast"}]
+ defaults:
+ run:
+ working-directory: ${{ github.workspace }}
+ env:
+ GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }}
+ WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
+ MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }}
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Install
+ shell: bash
+ run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages
+
+ - name: Execute
+ shell: bash
+ run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }}
+
e2e-test-short:
if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short'))
runs-on: self-hosted
@@ -38,10 +78,6 @@ jobs:
--ulimit stack=67108864
--memory=0
--memory-swap=0
- -e http_proxy=$http_proxy
- -e https_proxy=$https_proxy
- -e HTTP_PROXY=$HTTP_PROXY
- -e HTTPS_PROXY=$HTTPS_PROXY
-v /mnt/nvme0n1/miles_ci:/data/miles_ci
-v /mnt/nvme0n1/miles_ci/models:/root/models
-v /mnt/nvme0n1/miles_ci/datasets:/root/datasets
@@ -82,10 +118,6 @@ jobs:
--ulimit stack=67108864
--memory=0
--memory-swap=0
- -e http_proxy=$http_proxy
- -e https_proxy=$https_proxy
- -e HTTP_PROXY=$HTTP_PROXY
- -e HTTPS_PROXY=$HTTPS_PROXY
-v /mnt/nvme0n1/miles_ci:/data/miles_ci
-v /mnt/nvme0n1/miles_ci/models:/root/models
-v /mnt/nvme0n1/miles_ci/datasets:/root/datasets
@@ -126,10 +158,6 @@ jobs:
--ulimit stack=67108864
--memory=0
--memory-swap=0
- -e http_proxy=$http_proxy
- -e https_proxy=$https_proxy
- -e HTTP_PROXY=$HTTP_PROXY
- -e HTTPS_PROXY=$HTTPS_PROXY
-v /mnt/nvme0n1/miles_ci:/data/miles_ci
-v /mnt/nvme0n1/miles_ci/models:/root/models
-v /mnt/nvme0n1/miles_ci/datasets:/root/datasets
@@ -170,10 +198,6 @@ jobs:
--ulimit stack=67108864
--memory=0
--memory-swap=0
- -e http_proxy=$http_proxy
- -e https_proxy=$https_proxy
- -e HTTP_PROXY=$HTTP_PROXY
- -e HTTPS_PROXY=$HTTPS_PROXY
-v /mnt/nvme0n1/miles_ci:/data/miles_ci
-v /mnt/nvme0n1/miles_ci/models:/root/models
-v /mnt/nvme0n1/miles_ci/datasets:/root/datasets
@@ -214,10 +238,6 @@ jobs:
--ulimit stack=67108864
--memory=0
--memory-swap=0
- -e http_proxy=$http_proxy
- -e https_proxy=$https_proxy
- -e HTTP_PROXY=$HTTP_PROXY
- -e HTTPS_PROXY=$HTTPS_PROXY
-v /mnt/nvme0n1/miles_ci:/data/miles_ci
-v /mnt/nvme0n1/miles_ci/models:/root/models
-v /mnt/nvme0n1/miles_ci/datasets:/root/datasets
@@ -258,10 +278,6 @@ jobs:
--ulimit stack=67108864
--memory=0
--memory-swap=0
- -e http_proxy=$http_proxy
- -e https_proxy=$https_proxy
- -e HTTP_PROXY=$HTTP_PROXY
- -e HTTPS_PROXY=$HTTPS_PROXY
-v /mnt/nvme0n1/miles_ci:/data/miles_ci
-v /mnt/nvme0n1/miles_ci/models:/root/models
-v /mnt/nvme0n1/miles_ci/datasets:/root/datasets
@@ -302,10 +318,6 @@ jobs:
--ulimit stack=67108864
--memory=0
--memory-swap=0
- -e http_proxy=$http_proxy
- -e https_proxy=$https_proxy
- -e HTTP_PROXY=$HTTP_PROXY
- -e HTTPS_PROXY=$HTTPS_PROXY
-v /mnt/nvme0n1/miles_ci:/data/miles_ci
-v /mnt/nvme0n1/miles_ci/models:/root/models
-v /mnt/nvme0n1/miles_ci/datasets:/root/datasets
diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2
index 25bb2bce2..055dfee63 100644
--- a/.github/workflows/pr-test.yml.j2
+++ b/.github/workflows/pr-test.yml.j2
@@ -1,4 +1,10 @@
<% set jobs = {
+ 'fast': {
+ 'test_executor': 'pytest',
+ 'tests': [
+ {'test_file': 'fast', 'num_gpus': 0},
+ ],
+ },
'e2e-test-short': {
'label': 'run-ci-short',
'tests': [
@@ -95,7 +101,7 @@ concurrency:
jobs:
<% for job_name, config in jobs.items() %>
<< job_name >>:
- if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, '<< config.label >>'))
+ if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request<% if config.label %> && contains(github.event.pull_request.labels.*.name, '<< config.label >>')<% endif %>)
runs-on: self-hosted
container:
image: << config.image if config.image else 'radixark/miles:latest' >>
@@ -107,10 +113,6 @@ jobs:
--ulimit stack=67108864
--memory=0
--memory-swap=0
- -e http_proxy=$http_proxy
- -e https_proxy=$https_proxy
- -e HTTP_PROXY=$HTTP_PROXY
- -e HTTPS_PROXY=$HTTPS_PROXY
-v /mnt/nvme0n1/miles_ci:/data/miles_ci
-v /mnt/nvme0n1/miles_ci/models:/root/models
-v /mnt/nvme0n1/miles_ci/datasets:/root/datasets
@@ -136,5 +138,5 @@ jobs:
- name: Execute
shell: bash
- run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }}
+ run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- << config.test_executor | default('python') >> tests/${{ matrix.info.test_file }}
<% endfor %>
\ No newline at end of file
diff --git a/examples/openai_format/__init__.py b/examples/openai_format/__init__.py
new file mode 100644
index 000000000..30436bcc4
--- /dev/null
+++ b/examples/openai_format/__init__.py
@@ -0,0 +1 @@
+"""OpenAI format examples."""
diff --git a/examples/openai_format/dapo_math.py b/examples/openai_format/dapo_math.py
new file mode 100644
index 000000000..6fe69433e
--- /dev/null
+++ b/examples/openai_format/dapo_math.py
@@ -0,0 +1,57 @@
+"""
+DAPO math OpenAI format example for token in/out verification.
+"""
+
+import argparse
+from typing import Any
+
+from openai import AsyncOpenAI
+
+from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput
+from miles.rollout.generate_utils.openai_endpoint_utils import (
+ OpenAIEndpointTracer,
+ compute_samples_from_openai_records,
+)
+from miles.rollout.generate_utils.sample_utils import merge_samples
+
+_DAPO_MATH_SYSTEM_PROMPT = (
+ "Solve the math problem and return the final answer as \\boxed{integer}. "
+ "Keep the reasoning concise and finish with the boxed answer."
+)
+
+
+async def generate(input: GenerateFnInput) -> GenerateFnOutput:
+ tracer = await OpenAIEndpointTracer.create(input.args)
+ messages = _normalize_prompt(input.sample.prompt)
+ await _run_single_turn_openai(base_url=tracer.base_url, messages=messages)
+
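+    # The chat completion response itself is not used directly; token ids and logprobs are
+    # recovered from the router session records traced by OpenAIEndpointTracer below.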
+ records = await tracer.collect_records()
+ samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer)
+ if not input.args.generate_multi_samples:
+ samples = merge_samples(samples, input.state.tokenizer)
+ return GenerateFnOutput(samples=samples)
+
+
+def _add_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument("--generate-multi-samples", action="store_true")
+
+
+generate.add_arguments = _add_arguments
+
+
+def build_dapo_math_messages(question: str) -> list[dict[str, str]]:
+ return [
+ {"role": "system", "content": _DAPO_MATH_SYSTEM_PROMPT},
+ {"role": "user", "content": question},
+ ]
+
+
+def _normalize_prompt(prompt: Any) -> list[dict[str, Any]]:
+ if isinstance(prompt, list):
+ return prompt
+ return build_dapo_math_messages(prompt)
+
+
+async def _run_single_turn_openai(base_url: str, messages: list[dict[str, Any]]) -> None:
+ client = AsyncOpenAI(base_url=base_url, api_key="empty")
+ await client.chat.completions.create(model="default", messages=messages)
diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py
index 79c6649be..27211845d 100644
--- a/miles/ray/rollout.py
+++ b/miles/ray/rollout.py
@@ -13,8 +13,15 @@
from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from miles.backends.sglang_utils.sglang_engine import SGLangEngine
-from miles.rollout.base_types import call_rollout_fn
+from miles.rollout.base_types import (
+ RolloutFnConstructorInput,
+ RolloutFnEvalInput,
+ RolloutFnTrainInput,
+ call_rollout_fn,
+)
+from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function
from miles.utils import tracking_utils
+from miles.utils.environ import enable_experimental_rollout_refactor
from miles.utils.health_monitor import RolloutHealthMonitor
from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client
from miles.utils.iter_utils import group_by
@@ -53,8 +60,14 @@ def __init__(self, args, pg):
data_source_cls = load_function(self.args.data_source_path)
self.data_source = data_source_cls(args)
- self.generate_rollout = load_function(self.args.rollout_function_path)
- self.eval_generate_rollout = load_function(self.args.eval_function_path)
+ self.use_experimental_refactor = enable_experimental_rollout_refactor()
+ if self.use_experimental_refactor:
+ input = RolloutFnConstructorInput(args=args, data_source=self.data_source)
+ self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path)
+ self.eval_generate_rollout = load_rollout_function(input, self.args.eval_function_path)
+ else:
+ self.generate_rollout = load_function(self.args.rollout_function_path)
+ self.eval_generate_rollout = load_function(self.args.eval_function_path)
self.custom_reward_post_process_func = None
if self.args.custom_reward_post_process_path is not None:
self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path)
@@ -142,7 +155,12 @@ def eval(self, rollout_id):
return
self.health_monitoring_resume()
- result = call_rollout_fn(self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True)
+ if self.use_experimental_refactor:
+ result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id))
+ else:
+ result = call_rollout_fn(
+ self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True
+ )
data = result.data
self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True)
metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics)
@@ -224,7 +242,12 @@ def _get_rollout_data(self, rollout_id):
)
metrics = None
else:
- data = call_rollout_fn(self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False)
+ if self.use_experimental_refactor:
+ data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id))
+ else:
+ data = call_rollout_fn(
+ self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False
+ )
metrics = data.metrics
data = data.samples
# flatten the data if it is a list of lists
diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py
index faa85c726..c2644e87f 100644
--- a/miles/rollout/base_types.py
+++ b/miles/rollout/base_types.py
@@ -1,22 +1,86 @@
+from __future__ import annotations
+
+from argparse import Namespace
from dataclasses import dataclass
-from typing import Any
+from typing import TYPE_CHECKING, Any
+from miles.rollout.data_source import DataSource
from miles.utils.types import Sample
+if TYPE_CHECKING:
+ from miles.rollout.inference_rollout.inference_rollout_common import GenerateState
+
+
+@dataclass(frozen=True)
+class RolloutFnConstructorInput:
+ args: Namespace
+ # TODO may refactor DataSource API
+ data_source: DataSource
+
+
+@dataclass(frozen=True)
+class RolloutFnBaseInput:
+ rollout_id: int
+
+ @property
+ def evaluation(self):
+ raise NotImplementedError
+
+
+# Subclasses can carry different data in the future
+@dataclass(frozen=True)
+class RolloutFnTrainInput(RolloutFnBaseInput):
+ @property
+ def evaluation(self):
+ return False
+
+@dataclass(frozen=True)
+class RolloutFnEvalInput(RolloutFnBaseInput):
+ @property
+ def evaluation(self):
+ return True
+
+
+# TODO make it frozen
@dataclass
class RolloutFnTrainOutput:
samples: list[list[Sample]]
metrics: dict[str, Any] = None
+# TODO make it frozen
@dataclass
class RolloutFnEvalOutput:
data: dict[str, dict[str, Any]]
metrics: dict[str, Any] = None
+RolloutFnInput = RolloutFnTrainInput | RolloutFnEvalInput
+RolloutFnOutput = RolloutFnTrainOutput | RolloutFnEvalOutput
+
+
+@dataclass(frozen=True)
+class GenerateFnInput:
+ state: GenerateState
+ sample: Sample
+ sampling_params: dict[str, Any]
+ evaluation: bool
+
+ @property
+ def args(self) -> Namespace:
+ return self.state.args
+
+
+@dataclass(frozen=True)
+class GenerateFnOutput:
+    # One generate call may produce multiple samples, e.g. multi-agent systems, tree-like
+    # exploration, or multi-turn generation where thinking tokens are removed.
+ samples: Sample | list[Sample]
+
+
def call_rollout_fn(fn, *args, evaluation: bool, **kwargs):
+ """Legacy rollout function call interface. Used when MILES_EXPERIMENTAL_ROLLOUT_REFACTOR is disabled."""
output = fn(*args, **kwargs, evaluation=evaluation)
# compatibility for legacy version
diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py
new file mode 100644
index 000000000..05223a654
--- /dev/null
+++ b/miles/rollout/generate_hub/agentic_tool_call.py
@@ -0,0 +1,85 @@
+"""
+Simple agentic demo with tool calling.
+"""
+
+import argparse
+from copy import deepcopy
+from typing import Any
+
+from openai import AsyncOpenAI
+
+from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput
+from miles.rollout.generate_utils.openai_endpoint_utils import (
+ OpenAIEndpointTracer,
+ compute_samples_from_openai_records,
+)
+from miles.rollout.generate_utils.sample_utils import merge_samples
+from miles.rollout.generate_utils.tool_call_utils import execute_tool_calls
+from miles.utils.misc import load_function
+
+
+async def generate(input: GenerateFnInput) -> GenerateFnOutput:
+ tracer = await OpenAIEndpointTracer.create(input.args)
+
+ await _run_blackbox_tool_call_agent(
+ base_url=tracer.base_url,
+ prompt=input.sample.prompt,
+ max_turns=input.args.generate_max_turns,
+ tool_specs_path=input.args.generate_tool_specs_path,
+ execute_tool_function_path=input.args.generate_execute_tool_function_path,
+ )
+
+ records = await tracer.collect_records()
+ samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer)
+ if not input.args.generate_multi_samples:
+ samples = merge_samples(samples, input.state.tokenizer)
+ return GenerateFnOutput(samples=samples)
+
+
+def _add_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument("--generate-max-turns", type=int, default=16)
+ parser.add_argument("--generate-tool-specs-path", type=str)
+ parser.add_argument("--generate-execute-tool-function-path", type=str)
+ parser.add_argument("--generate-multi-samples", action="store_true")
+
+
+generate.add_arguments = _add_arguments
+
+
+async def _run_blackbox_tool_call_agent(
+ base_url: str,
+ prompt: list[dict[str, Any]],
+ max_turns: int,
+ tool_specs_path: str,
+ execute_tool_function_path: str,
+):
+ """
+    Imagine this is a black-box agent, e.g. SWE-agent, which does arbitrarily complex work,
+    only speaks the OpenAI-compatible API, and knows nothing about Miles or the Sample data structure.
+ """
+
+ # ----------------------- Setup -------------------------
+
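+    # Both paths are resolved via load_function: execute_tool_function_path should point to an
+    # async callable (name, params) -> str, and tool_specs_path to a list of OpenAI-style tool specs.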
+ client = AsyncOpenAI(base_url=base_url, api_key="empty")
+ execute_tool_function = load_function(execute_tool_function_path)
+ tool_specs = load_function(tool_specs_path)
+
+ # ----------------------- Initial prompts -------------------------
+
+ messages = deepcopy(prompt)
+
+ for _turn in range(max_turns):
+ # ----------------------- Call inference endpoint -------------------------
+
+ response = await client.chat.completions.create(model="default", messages=messages, tools=tool_specs)
+
+ choice = response.choices[0]
+ messages.append(choice.message.model_dump())
+
+ if choice.finish_reason in ("stop", "length"):
+ break
+
+ # ----------------------- Execute tools -------------------------
+
+ if x := choice.message.tool_calls:
+ messages += await execute_tool_calls(x, execute_tool_function)
diff --git a/miles/rollout/generate_hub/multi_turn.py b/miles/rollout/generate_hub/multi_turn.py
new file mode 100644
index 000000000..97814ecb3
--- /dev/null
+++ b/miles/rollout/generate_hub/multi_turn.py
@@ -0,0 +1,88 @@
+"""
+Simple multi-turn generation with tool calling.
+"""
+
+import argparse
+from copy import deepcopy
+
+from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput
+from miles.rollout.generate_utils.generate_endpoint_utils import (
+ compute_prompt_ids_from_sample,
+ compute_request_payload,
+ update_sample_from_response,
+)
+from miles.rollout.generate_utils.tool_call_utils import (
+ create_tool_call_parser,
+ execute_tool_calls,
+ update_sample_with_tool_responses,
+)
+from miles.utils.http_utils import post
+from miles.utils.misc import load_function
+
+
+async def generate(input: GenerateFnInput) -> GenerateFnOutput:
+ # ----------------------- Setup -------------------------
+
+ args = input.args
+ sample = deepcopy(input.sample)
+ tokenizer = input.state.tokenizer
+ assert not args.partial_rollout, "Partial rollout is not supported"
+
+ url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
+
+ execute_tool_function = load_function(args.generate_execute_tool_function_path)
+
+ tool_specs = load_function(args.generate_tool_specs_path)
+ tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser)
+
+ multi_samples = []
+
+ # ----------------------- Initial prompts -------------------------
+
+ prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs)
+
+ sample.tokens = prompt_tokens_ids.copy()
+
+ for _turn in range(args.generate_max_turns):
+ # ----------------------- Call inference endpoint -------------------------
+
+ payload, halt_status = compute_request_payload(args, sample.tokens, input.sampling_params)
+ if payload is None:
+ sample.status = halt_status
+ if args.generate_multi_samples and multi_samples:
+ multi_samples[-1].status = halt_status
+ break
+
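+        # In multi-sample mode, each turn becomes its own Sample: start from a fresh copy of the
+        # original sample; update_sample_from_response then sets its tokens from the request payload.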
+ if args.generate_multi_samples:
+ sample = deepcopy(input.sample)
+
+ output = await post(url, payload)
+ await update_sample_from_response(args, sample, payload=payload, output=output, update_loss_mask=True)
+
+ if args.generate_multi_samples:
+ multi_samples.append(deepcopy(sample))
+
+ if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"):
+ break
+
+ # ----------------------- Execute tools -------------------------
+
+ _, tool_calls = tool_call_parser.parse_non_stream(output["text"])
+ if len(tool_calls) == 0:
+ break
+
+ tool_messages = await execute_tool_calls(tool_calls, execute_tool_function)
+ update_sample_with_tool_responses(sample, tool_messages, tokenizer=tokenizer)
+
+ return GenerateFnOutput(samples=multi_samples if args.generate_multi_samples else sample)
+
+
+def _add_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument("--generate-max-turns", type=int, default=16)
+ parser.add_argument("--generate-tool-specs-path", type=str)
+ parser.add_argument("--generate-tool-call-parser", type=str)
+ parser.add_argument("--generate-execute-tool-function-path", type=str)
+ parser.add_argument("--generate-multi-samples", action="store_true")
+
+
+generate.add_arguments = _add_arguments
diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py
new file mode 100644
index 000000000..5c0a15b5b
--- /dev/null
+++ b/miles/rollout/generate_hub/single_turn.py
@@ -0,0 +1,46 @@
+"""
+Simple single-turn generation.
+"""
+
+from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput
+from miles.rollout.generate_utils.generate_endpoint_utils import (
+ compute_prompt_ids_from_sample,
+ compute_request_payload,
+ update_sample_from_response,
+)
+from miles.utils.http_utils import post
+from miles.utils.types import Sample
+
+
+async def generate(input: GenerateFnInput) -> GenerateFnOutput:
+ args = input.args
+ sample = input.sample
+ sampling_params = input.sampling_params
+ assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}"
+ url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
+
+ prompt_ids = compute_prompt_ids_from_sample(input.state, sample)
+
+ # Handle Partial Rollout resuming
+ if len(sample.response) > 0:
+ input_ids = sample.tokens
+ sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids)
+
+ assert sampling_params["max_new_tokens"] >= 0
+ if sampling_params["max_new_tokens"] == 0:
+ sample.status = Sample.Status.TRUNCATED
+ return GenerateFnOutput(samples=sample)
+ else:
+ input_ids = prompt_ids
+
+ payload, halt_status = compute_request_payload(
+ args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs
+ )
+ if payload is None:
+ sample.status = halt_status
+ return GenerateFnOutput(samples=sample)
+
+ output = await post(url, payload)
+ await update_sample_from_response(args, sample, payload=payload, output=output)
+
+ return GenerateFnOutput(samples=sample)
diff --git a/miles/rollout/generate_utils/__init__.py b/miles/rollout/generate_utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/miles/rollout/generate_utils/generate_endpoint_utils.py b/miles/rollout/generate_utils/generate_endpoint_utils.py
new file mode 100644
index 000000000..a91d71f1d
--- /dev/null
+++ b/miles/rollout/generate_utils/generate_endpoint_utils.py
@@ -0,0 +1,112 @@
+"""
+Utils to integrate SGLang's `/generate` endpoint with RL constructs such as Sample.
+"""
+
+from copy import deepcopy
+from typing import Any
+
+import numpy as np
+import pybase64
+
+from miles.utils.processing_utils import encode_image_for_rollout_engine
+from miles.utils.types import Sample
+
+
+# Kept as a standalone function because users may want to compute their own prompt ids
+def compute_prompt_ids_from_sample(state, sample, tools=None):
+ prompt = sample.prompt
+
+ if state.processor:
+ processor_output = state.processor(text=prompt, **sample.multimodal_inputs)
+ prompt_ids = processor_output["input_ids"][0]
+
+        # TODO shall we move this elsewhere? Then this function could be made free of side effects
+ sample.multimodal_train_inputs = {
+ k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]
+ } or None
+
+ return prompt_ids
+ else:
+ if not isinstance(prompt, str):
+ prompt = state.tokenizer.apply_chat_template(
+ prompt, tokenize=False, add_generation_prompt=True, tools=tools
+ )
+
+ return state.tokenizer.encode(prompt, add_special_tokens=False)
+
+
+def compute_request_payload(
+ args,
+ input_ids: list[int],
+ sampling_params: dict,
+ multimodal_inputs: dict | None = None,
+) -> tuple[dict[str, Any] | None, Sample.Status | None]:
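+    # Returns (payload, None) on success, or (None, Sample.Status.TRUNCATED) when the remaining
+    # context budget (rollout_max_context_len minus the input length) leaves no room for new tokens.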
+ sampling_params = deepcopy(sampling_params)
+ max_new_tokens = sampling_params.pop("max_new_tokens", args.rollout_max_response_len)
+ if x := args.rollout_max_context_len:
+ max_new_tokens = min(max_new_tokens, x - len(input_ids))
+ if max_new_tokens <= 0:
+ return None, Sample.Status.TRUNCATED
+
+ payload = {
+ "input_ids": input_ids,
+ "sampling_params": {**sampling_params, "max_new_tokens": max_new_tokens},
+ "return_logprob": True,
+ "return_routed_experts": args.use_rollout_routing_replay,
+ }
+ if image_data := (multimodal_inputs or {}).get("images"):
+ payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data]
+
+ return payload, None
+
+
+async def update_sample_from_response(
+ args, sample: Sample, payload: dict, output: dict, update_loss_mask: bool = False
+):
+ # Initialize sample.tokens for the first turn
+ if (len(sample.response) == 0) and not sample.tokens:
+ sample.tokens = payload["input_ids"]
+
+ if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths:
+ from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree
+
+ # TODO may rename to match
+ await postprocess_sample_with_radix_tree(args, sample, output)
+
+ assert not update_loss_mask, "This code branch has not implemented update_loss_mask"
+ else:
+ if x := output["meta_info"].get("output_token_logprobs"):
+ new_response_tokens = [item[1] for item in x]
+ new_response_log_probs = [item[0] for item in x]
+ else:
+ new_response_tokens, new_response_log_probs = [], []
+
+ # Update sample with tokens directly - avoiding re-tokenization
+ sample.tokens = sample.tokens + new_response_tokens
+ sample.response_length += len(new_response_tokens)
+ sample.response += output["text"]
+
+ if sample.rollout_log_probs is None:
+ sample.rollout_log_probs = []
+ sample.rollout_log_probs += new_response_log_probs
+
+ if update_loss_mask:
+ if sample.loss_mask is None:
+ sample.loss_mask = []
+ sample.loss_mask += [1] * len(new_response_tokens)
+
+ # TODO handle multi-turn cases (may need concat instead of assignment)
+ sample.rollout_routed_experts = _get_rollout_routed_experts_from_response(args, sample, output)
+
+ # TODO may unify (currently there are both methods inside Sample and separate functions)
+ sample.update_from_meta_info(args, output["meta_info"])
+
+
+def _get_rollout_routed_experts_from_response(args, sample, output):
+ info = output["meta_info"].get("routed_experts")
+ if info is None:
+ return None
+
+ x = np.frombuffer(pybase64.b64decode(info.encode("ascii")), dtype=np.int32)
+ x = x.reshape(len(sample.tokens) - 1, args.num_layers, args.moe_router_topk)
+ return x
diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py
new file mode 100644
index 000000000..a7a3a3e4a
--- /dev/null
+++ b/miles/rollout/generate_utils/openai_endpoint_utils.py
@@ -0,0 +1,69 @@
+"""
+Utilities for the OpenAI endpoint
+"""
+
+import logging
+from argparse import Namespace
+from copy import deepcopy
+
+from miles.router.session.sessions import GetSessionResponse, SessionRecord
+from miles.utils.http_utils import post
+from miles.utils.types import Sample
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIEndpointTracer:
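+    """Trace OpenAI-compatible calls through a router session so token ids/logprobs can be collected later."""
+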
+ def __init__(self, router_url: str, session_id: str):
+ self.router_url = router_url
+ self.session_id = session_id
+ self.base_url = f"{router_url}/sessions/{session_id}/v1"
+
+ @staticmethod
+ async def create(args: Namespace):
+ router_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}"
+ session_id = (await post(f"{router_url}/sessions", {}))["session_id"]
+ return OpenAIEndpointTracer(router_url=router_url, session_id=session_id)
+
+ async def collect_records(self) -> list[SessionRecord]:
+ response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="get")
+ response = GetSessionResponse.model_validate(response)
+ records = response.session_records
+ if records is None and isinstance(response.records, list):
+ records = response.records
+
+ try:
+ await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="delete")
+ except Exception as e:
+ logger.warning(f"Failed to delete session {self.session_id} after collecting records: {e}")
+
+ return records or []
+
+
+def compute_samples_from_openai_records(input_sample: Sample, records: list[SessionRecord], tokenizer) -> list[Sample]:
+ return [_compute_sample_from_openai_record(input_sample, record, tokenizer) for record in records]
+
+
+def _compute_sample_from_openai_record(input_sample: Sample, record: SessionRecord, tokenizer) -> Sample:
+ # TODO may refine after @guapisolo's implementation
+ choice = record.response["choices"][0]
+ output_token_ids = [item["token_id"] for item in choice["logprobs"]["content"]]
+ output_log_probs = [item["logprob"] for item in choice["logprobs"]["content"]]
+
+ sample = deepcopy(input_sample)
+ sample.tokens = record.request["input_ids"] + output_token_ids
+ sample.rollout_log_probs = output_log_probs
+ sample.response = tokenizer.decode(output_token_ids)
+ sample.response_length = len(output_token_ids)
+ sample.loss_mask = [1] * len(output_token_ids)
+
+ # TODO unify with Sample.update_from_meta_info
+ match choice["finish_reason"]:
+ case "stop" | "tool_calls":
+ sample.status = Sample.Status.COMPLETED
+ case "length":
+ sample.status = Sample.Status.TRUNCATED
+ case "abort":
+ sample.status = Sample.Status.ABORTED
+
+ return sample
diff --git a/miles/rollout/generate_utils/sample_utils.py b/miles/rollout/generate_utils/sample_utils.py
new file mode 100644
index 000000000..6a4e645be
--- /dev/null
+++ b/miles/rollout/generate_utils/sample_utils.py
@@ -0,0 +1,115 @@
+from copy import deepcopy
+from dataclasses import fields
+
+from miles.utils.types import Sample
+
+
+def merge_samples(samples: list[Sample], tokenizer) -> Sample:
+ acc = samples[0]
+ for sample in samples[1:]:
+ acc = _merge_sample_pair(acc, sample, tokenizer=tokenizer)
+ return acc
+
+
+def _merge_sample_pair(a: Sample, b: Sample, tokenizer) -> Sample:
+ """Merge two samples generated from sibling inference engine calls."""
+ a, b = deepcopy(a), deepcopy(b)
+
+ def _merge_equal_value(field):
+ x = getattr(a, field)
+ y = getattr(b, field)
+ assert x == y, f"{field} mismatch: a.{field}={x}, b.{field}={y}"
+ return x
+
+ def _fill_defaults(sample: Sample):
+ if sample.loss_mask is None:
+ sample.loss_mask = [1] * sample.response_length
+ if sample.rollout_log_probs is None:
+ sample.rollout_log_probs = [0.0] * sample.response_length
+
+ _fill_defaults(a)
+ _fill_defaults(b)
+
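+    # Tokens that b adds on top of a before b's own response (e.g. tool results / observations);
+    # they are loss-masked and assigned zero logprob in the merged sample below.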
+ obs_len = len(b.tokens) - len(a.tokens) - b.response_length
+ obs_tokens = b.tokens[len(a.tokens) : len(a.tokens) + obs_len]
+ # TODO: is this acceptable?
+ obs_text = tokenizer.decode(obs_tokens)
+
+ try:
+ a.validate()
+ b.validate()
+ assert _startswith(short=a.prompt, long=b.prompt), "b.prompt must start with a.prompt"
+ assert _startswith(short=a.tokens, long=b.tokens), "b.tokens must start with a.tokens"
+ assert obs_len > 0, f"obs_len must be > 0, got {obs_len}"
+ if a.rollout_routed_experts is not None:
+ assert a.rollout_routed_experts.shape[0] <= b.rollout_routed_experts.shape[0]
+ assert a.status == Sample.Status.COMPLETED, f"a.status must be COMPLETED, got {a.status}"
+
+ return _create_with_all_fields(
+ Sample,
+ group_index=_merge_equal_value("group_index"),
+ index=_merge_equal_value("index"),
+ prompt=b.prompt,
+ tokens=b.tokens,
+ multimodal_inputs=_merge_equal_value("multimodal_inputs"),
+ multimodal_train_inputs=_merge_equal_value("multimodal_train_inputs"),
+ response=a.response + obs_text + b.response,
+ response_length=a.response_length + obs_len + b.response_length,
+ label=_merge_equal_value("label"),
+ reward=_merge_equal_value("reward"),
+ loss_mask=a.loss_mask + [0] * obs_len + b.loss_mask,
+ weight_versions=a.weight_versions + b.weight_versions,
+ rollout_log_probs=a.rollout_log_probs + [0.0] * obs_len + b.rollout_log_probs,
+ rollout_routed_experts=b.rollout_routed_experts,
+ remove_sample=_merge_equal_value("remove_sample"),
+ status=b.status,
+ metadata=_merge_equal_value("metadata"),
+ train_metadata=_merge_equal_value("train_metadata"),
+ non_generation_time=_merge_equal_value("non_generation_time"),
+ spec_info=_merge_spec_info(a.spec_info, b.spec_info),
+ prefix_cache_info=_merge_prefix_cache_info(a.prefix_cache_info, b.prefix_cache_info),
+ )
+ except AssertionError as e:
+ e.add_note(f"{a=} {b=}")
+ raise
+
+
+def _merge_spec_info(a: Sample.SpecInfo, b: Sample.SpecInfo) -> Sample.SpecInfo:
+ def _merge_plus_value(field):
+ return getattr(a, field) + getattr(b, field)
+
+ return _create_with_all_fields(
+ Sample.SpecInfo,
+ spec_accept_token_num=_merge_plus_value("spec_accept_token_num"),
+ spec_draft_token_num=_merge_plus_value("spec_draft_token_num"),
+ spec_verify_ct=_merge_plus_value("spec_verify_ct"),
+ completion_token_num=_merge_plus_value("completion_token_num"),
+ )
+
+
+def _merge_prefix_cache_info(a: Sample.PrefixCacheInfo, b: Sample.PrefixCacheInfo) -> Sample.PrefixCacheInfo:
+ def _merge_plus_value(field):
+ return getattr(a, field) + getattr(b, field)
+
+ return _create_with_all_fields(
+ Sample.PrefixCacheInfo,
+ cached_tokens=_merge_plus_value("cached_tokens"),
+ total_prompt_tokens=_merge_plus_value("total_prompt_tokens"),
+ )
+
+
+def _create_with_all_fields(cls, **kwargs):
+ expected = {f.name for f in fields(cls)}
+ actual = set(kwargs.keys())
+ assert (
+ expected == actual
+ ), f"{cls.__name__} field mismatch. Missing: {expected - actual}, Extra: {actual - expected}"
+ return cls(**kwargs)
+
+
+def _startswith(*, short, long) -> bool:
+ if isinstance(short, str) and isinstance(long, str):
+ return long.startswith(short)
+ if isinstance(short, list) and isinstance(long, list):
+ return (len(long) >= len(short)) and (long[: len(short)] == short)
+ raise NotImplementedError
diff --git a/miles/rollout/generate_utils/tokenize_utils.py b/miles/rollout/generate_utils/tokenize_utils.py
new file mode 100644
index 000000000..ca4f59938
--- /dev/null
+++ b/miles/rollout/generate_utils/tokenize_utils.py
@@ -0,0 +1,52 @@
+from typing import Any
+from transformers import AutoTokenizer
+from miles.rollout.generate_utils.tool_call_utils import tokenize_tool_responses
+
+
+_DUMMY_SYSTEM = {"role": "system", "content": "FOR CALCULATING ADDITIONAL TOKENS ONLY"}
+_DUMMY_USER = {"role": "user", "content": "FOR CALCULATING ADDITIONAL TOKENS ONLY"}
+_DUMMY_ASSISTANT = {"role": "assistant", "content": "FOR CALCULATING ADDITIONAL TOKENS ONLY"}
+
+
+def calc_generation_prompt_tokens(tokenizer: AutoTokenizer) -> list[int]:
+ messages = [_DUMMY_SYSTEM, _DUMMY_USER, _DUMMY_ASSISTANT]
+ with_generation_prompt = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
+ without_generation_prompt = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)
+ assert with_generation_prompt[: len(without_generation_prompt)] == without_generation_prompt
+ return with_generation_prompt[len(without_generation_prompt) :]
+
+
+# TODO(jiajun): needs an e2e test to validate. Based on https://zhuanlan.zhihu.com/p/1917126584806139373
+# Notice: this function will automatically trim think tokens if the model's chat template trims thinking parts (e.g. Qwen3).
+def _naive_calc_additional_tokens(
+ message: dict[str, Any], tokenizer: AutoTokenizer, add_generation_prompt: bool = True
+) -> list[int]:
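+    # Sandwich the message between dummy turns, tokenize, and slice off the prefix/suffix token
+    # counts to isolate this message's tokens; the generation prompt tokens are appended at the end.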
+ prefix = [_DUMMY_SYSTEM, _DUMMY_USER, _DUMMY_ASSISTANT, _DUMMY_USER]
+ suffix = [_DUMMY_SYSTEM, _DUMMY_USER]
+ prefix_tokens = tokenizer.apply_chat_template(prefix, tokenize=True, add_special_tokens=False)
+ messages_tokens = tokenizer.apply_chat_template(
+ prefix + [message] + suffix, tokenize=True, add_special_tokens=False
+ )
+ suffix_tokens = tokenizer.apply_chat_template(suffix, tokenize=True, add_special_tokens=False)
+
+ response_tokens = messages_tokens[len(prefix_tokens) : -len(suffix_tokens)]
+ generation_prompt_tokens = calc_generation_prompt_tokens(tokenizer)
+ return response_tokens + generation_prompt_tokens
+
+
+# TODO(jiajun): needs an e2e test to validate.
+def tokenize_messages(
+ messages: list[dict[str, Any]],
+ tokenizer,
+ add_generation_prompt: bool = True,
+) -> list[int]:
+ token_ids = []
+ for message in messages:
+ if message["role"] == "assistant" or message["role"] == "user" or message["role"] == "system":
+ token_ids.extend(_naive_calc_additional_tokens(message, tokenizer, add_generation_prompt))
+ elif message["role"] == "tool":
+ token_ids.extend(tokenize_tool_responses([message], tokenizer))
+ else:
+ raise ValueError(f"Unsupported message role: {message['role']}")
+
+ return token_ids
diff --git a/miles/rollout/generate_utils/tool_call_utils.py b/miles/rollout/generate_utils/tool_call_utils.py
new file mode 100644
index 000000000..85ea87aea
--- /dev/null
+++ b/miles/rollout/generate_utils/tool_call_utils.py
@@ -0,0 +1,115 @@
+"""
+Utils to handle tool calls.
+"""
+
+import json
+import uuid
+from collections.abc import Callable
+from typing import Any
+
+from openai.types.chat import ChatCompletionMessageToolCall
+from pydantic import TypeAdapter
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.core_types import ToolCallItem
+from sglang.srt.function_call.function_call_parser import FunctionCallParser
+
+from miles.utils.types import Sample
+
+_DUMMY_USER = {"role": "user", "content": "dummy"}
+
+
+def create_tool_call_parser(tool_specs, tool_call_parser):
+ return FunctionCallParser(
+ tools=TypeAdapter(list[Tool]).validate_python(tool_specs),
+ tool_call_parser=tool_call_parser,
+ )
+
+
+async def execute_tool_calls(
+ tool_calls: list[ToolCallItem | ChatCompletionMessageToolCall],
+ execute_one: Callable,
+) -> list[dict[str, Any]]:
+ tool_messages = []
+ for call in tool_calls:
+ tool_messages.append(await _execute_tool_call(call, execute_one))
+ return tool_messages
+
+
+async def _execute_tool_call(
+ call: ToolCallItem | ChatCompletionMessageToolCall, execute_one: Callable
+) -> dict[str, Any]:
+ if isinstance(call, ChatCompletionMessageToolCall):
+ name = call.function.name
+ params = json.loads(call.function.arguments) if call.function.arguments else {}
+ tool_call_id = call.id
+ elif isinstance(call, ToolCallItem):
+ name = call.name
+ params = json.loads(call.parameters) if call.parameters else {}
+ tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+ else:
+ raise TypeError(f"Unsupported tool call type: {type(call)}")
+
+ result = await execute_one(name, params)
+ assert isinstance(result, str)
+
+ return {"role": "tool", "tool_call_id": tool_call_id, "content": result, "name": name}
+
+
+def update_sample_with_tool_responses(sample: Sample, tool_messages: list[dict[str, Any]], tokenizer):
+ next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer)
+ sample.response += tokenizer.decode(next_obs_tokens_ids)
+ sample.response_length += len(next_obs_tokens_ids)
+ sample.tokens += next_obs_tokens_ids
+ sample.loss_mask += [0] * len(next_obs_tokens_ids)
+ sample.rollout_log_probs += [0.0] * len(next_obs_tokens_ids)
+
+
+# TODO: very naive implementation; needs the to-be-implemented e2e test to validate it.
+def tokenize_tool_responses(
+ tool_messages: list[dict[str, Any]],
+ tokenizer,
+) -> list[int]:
+ return _tokenize_postfix_messages(tool_messages, tokenizer)
+
+
+def _tokenize_postfix_messages(
+ postfix_messages: list[dict[str, Any]],
+ tokenizer,
+) -> list[int]:
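+    # Tokenize a dummy conversation with and without the postfix messages appended and return the
+    # difference, i.e. only the tokens contributed by the postfix messages (plus the generation prompt).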
+ dummy_assistant = _build_dummy_assistant(postfix_messages)
+ base_messages = [_DUMMY_USER, dummy_assistant]
+
+ messages_without = base_messages
+ messages_with = base_messages + postfix_messages
+
+ tokens_with = tokenizer.apply_chat_template(messages_with, tokenize=True, add_generation_prompt=True)
+ tokens_without = tokenizer.apply_chat_template(messages_without, tokenize=True, add_generation_prompt=False)
+
+ assert tokens_with[: len(tokens_without)] == tokens_without, (
+ f"Fail to tokenize_tool_responses caused by token prefix mismatch. "
+ f"This can happen for thinking model or models with special chat template, "
+ f"and this simple example does not support it yet, "
+ f"since this means we cannot have a append-only token id list. "
+ f"{tokens_with=} {tokens_without=} "
+ f"{tokenizer.decode(tokens_with)=} {tokenizer.decode(tokens_without)=} "
+ )
+ return tokens_with[len(tokens_without) :]
+
+
+def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]:
+ return {
+ "role": "assistant",
+ "content": "",
+ "reasoning_content": " ",
+ "tool_calls": [
+ {
+ "id": resp.get("tool_call_id", f"call0000{i}"),
+ "type": "function",
+ "function": {
+ "name": resp.get("name", "dummy_func"),
+ "arguments": {},
+ },
+ }
+ for i, resp in enumerate(tool_responses)
+ ],
+ }
diff --git a/miles/rollout/inference_rollout/__init__.py b/miles/rollout/inference_rollout/__init__.py
new file mode 100644
index 000000000..33ccf17bf
--- /dev/null
+++ b/miles/rollout/inference_rollout/__init__.py
@@ -0,0 +1,2 @@
+# This is a refactor of the portions of sglang_rollout.py above the generate function,
+# and it is given a different name so that both code paths can exist at the same time.
diff --git a/miles/rollout/inference_rollout/compatibility.py b/miles/rollout/inference_rollout/compatibility.py
new file mode 100644
index 000000000..7711e0dd3
--- /dev/null
+++ b/miles/rollout/inference_rollout/compatibility.py
@@ -0,0 +1,84 @@
+import inspect
+from collections.abc import Callable
+
+from miles.rollout.base_types import (
+ GenerateFnInput,
+ GenerateFnOutput,
+ RolloutFnConstructorInput,
+ RolloutFnEvalOutput,
+ RolloutFnInput,
+ RolloutFnOutput,
+ RolloutFnTrainOutput,
+)
+from miles.utils.async_utils import run
+from miles.utils.misc import load_function
+
+
+class LegacyRolloutFnAdapter:
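+    """Adapts a legacy rollout function with signature (args, rollout_id, data_source, evaluation=...)
+    to the new RolloutFnInput/RolloutFnOutput interface."""
+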
+ def __init__(self, input: RolloutFnConstructorInput, fn: Callable):
+ self.args = input.args
+ self.data_source = input.data_source
+ self.fn = fn
+
+ def __call__(self, input: RolloutFnInput) -> RolloutFnOutput:
+ output = self.fn(self.args, input.rollout_id, self.data_source, evaluation=input.evaluation)
+
+ # compatibility for legacy version
+ if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)):
+ output = RolloutFnEvalOutput(data=output) if input.evaluation else RolloutFnTrainOutput(samples=output)
+
+ return output
+
+
+def load_rollout_function(input: RolloutFnConstructorInput, path: str):
+ fn = load_function(path)
+
+ if inspect.isclass(fn):
+ return fn(input)
+ else:
+ return LegacyRolloutFnAdapter(input, fn)
+
+
+def call_rollout_function(fn, input: RolloutFnInput) -> RolloutFnOutput:
+ output = fn(input)
+
+ if inspect.iscoroutine(output):
+ output = run(output)
+
+ return output
+
+
+class LegacyGenerateFnAdapter:
+ def __init__(self, fn: Callable):
+ self.fn = fn
+ self._has_evaluation_param = "evaluation" in inspect.signature(fn).parameters
+
+ async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput:
+ if self._has_evaluation_param:
+ output = await self.fn(input.args, input.sample, input.sampling_params, evaluation=input.evaluation)
+ else:
+ output = await self.fn(input.args, input.sample, input.sampling_params)
+
+ if not isinstance(output, GenerateFnOutput):
+ output = GenerateFnOutput(samples=output)
+
+ return output
+
+
+def load_generate_function(path: str):
+ fn = load_function(path)
+ if fn is None:
+ return None
+
+ if inspect.isclass(fn):
+ return fn()
+ elif _is_legacy_generate_fn(fn):
+ return LegacyGenerateFnAdapter(fn)
+ else:
+ return fn
+
+
+def _is_legacy_generate_fn(fn: Callable) -> bool:
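+    # Heuristic: legacy generate functions take (args, sample, sampling_params, ...) positionally,
+    # while the new interface takes a single GenerateFnInput argument named "input".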
+ sig = inspect.signature(fn)
+ params = list(sig.parameters.keys())
+ return len(params) >= 3 and params[0] != "input"
diff --git a/miles/rollout/inference_rollout/inference_rollout_common.py b/miles/rollout/inference_rollout/inference_rollout_common.py
new file mode 100644
index 000000000..8518c6e02
--- /dev/null
+++ b/miles/rollout/inference_rollout/inference_rollout_common.py
@@ -0,0 +1,192 @@
+import asyncio
+import logging
+from argparse import Namespace
+from copy import deepcopy
+from typing import Any
+
+from miles.rollout.base_types import (
+ GenerateFnInput,
+ RolloutFnConstructorInput,
+ RolloutFnEvalInput,
+ RolloutFnEvalOutput,
+ RolloutFnInput,
+ RolloutFnOutput,
+ RolloutFnTrainInput,
+ RolloutFnTrainOutput,
+)
+from miles.rollout.generate_hub.single_turn import generate
+from miles.rollout.inference_rollout.compatibility import load_generate_function
+from miles.rollout.rm_hub import async_rm, batched_async_rm
+from miles.utils.processing_utils import load_processor, load_tokenizer
+from miles.utils.types import Sample
+
+logger = logging.getLogger(__name__)
+
+
+class GenerateState:
+ def __init__(self, args: Namespace) -> None:
+ # persistent state for the generation process
+ self.args = args
+ self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True)
+ self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True)
+
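+        # Limit in-flight generate calls to roughly the per-server concurrency times the number of engines.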
+ self.generate_fn_semaphore = asyncio.Semaphore(
+ args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine
+ )
+ self.sampling_params: dict[str, Any] = compute_sampling_params(
+ args,
+ temperature=args.rollout_temperature,
+ top_p=args.rollout_top_p,
+ top_k=args.rollout_top_k,
+ max_new_tokens=args.rollout_max_response_len,
+ )
+
+ self.generate_function = load_generate_function(args.custom_generate_function_path) or generate
+
+ self.reset()
+
+ def reset(self) -> None:
+ self.aborted = False
+
+
+async def generate_and_rm(
+ state: GenerateState,
+ sample: Sample | list[Sample],
+ sampling_params: dict[str, Any],
+ evaluation: bool = False,
+) -> Sample | list[Sample]:
+ args = state.args
+
+ # mask previous off-policy generation for partial rollout
+ if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0:
+ sample.loss_mask = [0] * sample.response_length
+
+ # For samples with existing response, check if they're complete
+ if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED:
+ assert sample.response is not None
+ if not args.group_rm:
+ assert sample.reward is not None
+ return sample
+
+ # generate
+ async with state.generate_fn_semaphore:
+ if state.aborted:
+ sample.status = Sample.Status.ABORTED
+ return sample
+
+ output = await state.generate_function(
+ GenerateFnInput(
+ state=state,
+ sample=sample,
+ sampling_params=deepcopy(sampling_params),
+ evaluation=evaluation,
+ )
+ )
+ sample = output.samples
+
+ # TODO change to `if not args.group_rm: do reward model` for more clarity after the refactor below
+ # for the rm that need the whole group, we will not do the rm here
+ if args.group_rm:
+ return sample
+
+ # TODO: unify the two branches into one if we decide to use list as output type
+ # multi samples
+ if isinstance(sample, list):
+ samples = sample
+ if any([sample.status == Sample.Status.ABORTED for sample in samples]):
+ return samples
+
+        # for multi-agent systems, the reward of some samples is already calculated during generation.
+ samples_need_reward = [sample for sample in samples if sample.reward is None]
+ await batched_async_rm(args, samples_need_reward, inplace_set_reward_field=True)
+ return samples
+ else:
+ if sample.status == Sample.Status.ABORTED:
+ return sample
+        # for multi-turn environments, a reward may already have been assigned to the agent.
+ if sample.reward is None:
+ sample.reward = await async_rm(args, sample)
+
+ return sample
+
+
+async def generate_and_rm_group(
+ state: GenerateState, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False
+) -> list[Sample]:
+ args = state.args
+
+ if state.aborted:
+ return group
+
+ tasks = []
+ for idx, sample in enumerate(group):
+ current_sampling_params = sampling_params.copy()
+ if getattr(args, "sglang_enable_deterministic_inference", False):
+ current_sampling_params["sampling_seed"] = args.rollout_seed + idx
+ tasks.append(
+ asyncio.create_task(generate_and_rm(state, sample, current_sampling_params, evaluation=evaluation))
+ )
+
+ group = await asyncio.gather(*tasks)
+ if state.aborted:
+ return group
+
+ if args.group_rm:
+ await batched_async_rm(args, group, inplace_set_reward_field=True)
+
+ return group
+
+
+def compute_sampling_params(
+ args,
+ *,
+ # after unifying configuration, this can be further refactored
+ temperature,
+ top_p,
+ top_k,
+ max_new_tokens,
+):
+ return dict(
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ max_new_tokens=max_new_tokens,
+ stop=args.rollout_stop,
+ stop_token_ids=args.rollout_stop_token_ids,
+ skip_special_tokens=args.rollout_skip_special_tokens,
+ no_stop_trim=True,
+ spaces_between_special_tokens=False,
+ )
+
+
+class InferenceRolloutFn:
+ def __init__(self, input: RolloutFnConstructorInput):
+ self.data_source = input.data_source
+ self.state = GenerateState(input.args)
+ self.eval_prompt_dataset_cache = {}
+
+ async def __call__(self, input: RolloutFnInput) -> RolloutFnOutput:
+ if input.evaluation:
+ return await self._call_eval(input)
+ return await self._call_train(input)
+
+ async def _call_train(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput:
+ from miles.rollout.inference_rollout.inference_rollout_train import generate_rollout_async
+
+ output, aborted_samples = await generate_rollout_async(
+ self.state, input.rollout_id, self.data_source.get_samples
+ )
+ self.data_source.add_samples(aborted_samples)
+ return output
+
+ async def _call_eval(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput:
+ from miles.rollout.inference_rollout.inference_rollout_eval import eval_rollout_single_dataset
+
+ assert not self.state.args.group_rm, "Group RM is not supported for eval rollout"
+
+ coros = []
+ for dataset_cfg in getattr(self.state.args, "eval_datasets", []) or []:
+ coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.eval_prompt_dataset_cache))
+ results_list = await asyncio.gather(*coros)
+ results = {k: v for r in results_list for k, v in r.items()}
+ return RolloutFnEvalOutput(data=results)
diff --git a/miles/rollout/inference_rollout/inference_rollout_eval.py b/miles/rollout/inference_rollout/inference_rollout_eval.py
new file mode 100644
index 000000000..2d052be0a
--- /dev/null
+++ b/miles/rollout/inference_rollout/inference_rollout_eval.py
@@ -0,0 +1,112 @@
+import asyncio
+import copy
+import logging
+from typing import Any
+
+from tqdm import tqdm
+
+from miles.rollout.inference_rollout.inference_rollout_common import (
+ GenerateState,
+ compute_sampling_params,
+ generate_and_rm,
+)
+from miles.utils.data import Dataset
+from miles.utils.eval_config import EvalDatasetConfig
+from miles.utils.misc import as_completed_async
+from miles.utils.processing_utils import load_processor, load_tokenizer
+from miles.utils.types import Sample
+
+logger = logging.getLogger(__name__)
+
+
+async def eval_rollout_single_dataset(
+ state: GenerateState,
+ dataset_cfg: EvalDatasetConfig,
+ prompt_dataset_cache: dict[Any, Dataset],
+) -> dict[str, dict[str, list[Any]]]:
+ args = state.args
+ assert not args.group_rm, "Group RM is not supported for eval rollout"
+
+ cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template)
+ if cache_key not in prompt_dataset_cache:
+ tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True)
+ processor = load_processor(args.hf_checkpoint, trust_remote_code=True)
+ prompt_dataset_cache[cache_key] = Dataset(
+ path=dataset_cfg.path,
+ tokenizer=tokenizer,
+ processor=processor,
+ max_length=args.eval_max_prompt_len,
+ prompt_key=dataset_cfg.input_key,
+ label_key=dataset_cfg.label_key,
+ multimodal_keys=args.multimodal_keys,
+ metadata_key=dataset_cfg.metadata_key,
+ tool_key=dataset_cfg.tool_key,
+ apply_chat_template=args.apply_chat_template,
+ apply_chat_template_kwargs=args.apply_chat_template_kwargs,
+ )
+ dataset = prompt_dataset_cache[cache_key]
+
+ base_sampling_params = compute_sampling_params(
+ args,
+ temperature=dataset_cfg.temperature,
+ top_p=dataset_cfg.top_p,
+ top_k=dataset_cfg.top_k,
+ max_new_tokens=dataset_cfg.max_response_len,
+ )
+
+ tasks = []
+ # do multiple samples for eval prompts
+ sample_index = 0
+ for _i, prompt_sample in enumerate(dataset.samples):
+ for j in range(dataset_cfg.n_samples_per_eval_prompt):
+ # use the same prompt for multiple samples
+ sample = copy.deepcopy(prompt_sample)
+ sample.index = sample_index
+ sample_index += 1
+ sample.metadata = dataset_cfg.inject_metadata(getattr(sample, "metadata", None))
+ sampling_params = base_sampling_params
+ if getattr(args, "sglang_enable_deterministic_inference", False):
+ sampling_params = base_sampling_params.copy()
+ sampling_params["sampling_seed"] = args.rollout_seed + j
+ tasks.append(
+ asyncio.create_task(
+ generate_and_rm(
+ state,
+ sample,
+ sampling_params=sampling_params,
+ evaluation=True,
+ )
+ )
+ )
+
+ data = []
+ do_print = True
+ pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print)
+ async for sample in as_completed_async(tasks):
+ if do_print:
+ # TODO improve this after enhancing samples' type
+ s = (sample[0] if len(sample) > 0 else None) if isinstance(sample, list) else sample
+ if s is not None:
+ logger.info(
+ "eval_rollout_single_dataset example data: "
+ f"{[str(s.prompt) + s.response]} "
+ f"reward={s.reward}"
+ )
+ do_print = False
+ if isinstance(sample, list):
+ data.extend(sample)
+ else:
+ data.append(sample)
+ pbar.update(1)
+ pbar.close()
+
+ data.sort(key=lambda sample: sample.index)
+
+ reward_key = args.eval_reward_key or args.reward_key
+ return {
+ dataset_cfg.name: {
+ "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data],
+ "truncated": [sample.status == Sample.Status.TRUNCATED for sample in data],
+ "samples": data,
+ }
+ }
diff --git a/miles/rollout/inference_rollout/inference_rollout_train.py b/miles/rollout/inference_rollout/inference_rollout_train.py
new file mode 100644
index 000000000..bae94ec67
--- /dev/null
+++ b/miles/rollout/inference_rollout/inference_rollout_train.py
@@ -0,0 +1,146 @@
+import asyncio
+import logging
+from argparse import Namespace
+from collections.abc import Callable
+
+import sglang_router
+from packaging.version import parse
+from tqdm import tqdm
+
+from miles.rollout.base_types import RolloutFnTrainOutput
+from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter
+from miles.rollout.inference_rollout.inference_rollout_common import GenerateState, generate_and_rm_group
+from miles.utils.http_utils import get, post
+from miles.utils.misc import as_completed_async, load_function
+from miles.utils.types import Sample
+
+logger = logging.getLogger(__name__)
+
+
+async def abort(state: GenerateState, pendings: set, rollout_id: int) -> list[list[Sample]]:
+ args = state.args
+
+ assert not state.aborted
+ state.aborted = True
+
+ urls = await get_worker_urls(args)
+ logger.info(f"Abort request for {urls}")
+ await asyncio.gather(*[post(f"{url}/abort_request", {"abort_all": True}) for url in urls])
+
+ # make sure all the pending tasks are finished
+ aborted_samples = []
+ async for group in as_completed_async(pendings):
+ if not args.partial_rollout:
+ continue
+
+ # for partial rollout, collect the partial samples into the data buffer
+ for sample in group:
+ if sample.response and "start_rollout_id" not in sample.metadata:
+ sample.metadata["start_rollout_id"] = rollout_id
+ aborted_samples.append(group)
+
+ if args.partial_rollout:
+ logger.info(f"Collected {sum(len(x) for x in aborted_samples)} partial samples into the data buffer")
+
+ return aborted_samples
+
+
+async def get_worker_urls(args: Namespace):
+ if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router:
+ response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers")
+ return response["urls"]
+ else:
+ response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers")
+ return [worker["url"] for worker in response["workers"]]
+
+
+def submit_generate_tasks(state: GenerateState, samples: list[list[Sample]]):
+ return [
+ asyncio.create_task(
+ # submit a group of samples as a single task.
+ generate_and_rm_group(
+ state,
+ group,
+ sampling_params=state.sampling_params.copy(),
+ evaluation=False,
+ )
+ )
+ for group in samples
+ ]
+
+
+async def generate_rollout_async(
+ state: GenerateState, rollout_id: int, data_source: Callable[[int], list[list[Sample]]]
+) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]:
+ args = state.args
+ assert args.rollout_global_dataset
+
+ # instantiate data filters
+ dynamic_filter = load_function(args.dynamic_sampling_filter_path)
+
+ metric_gatherer = MetricGatherer()
+
+    # target_data_size is the number of prompt groups to keep (each group has n_samples_per_prompt samples)
+ target_data_size = args.rollout_batch_size
+
+ pendings = set()
+ data = []
+ all_data = []
+ do_print = True
+ pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation")
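+    # Keep over-sampling prompt groups until target_data_size groups pass the dynamic filter;
+    # any still-pending groups are aborted afterwards (and kept for partial rollout if enabled).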
+ while len(data) < target_data_size:
+ while len(data) + len(pendings) < target_data_size:
+ # get samples from the buffer and submit the generation requests.
+ samples = data_source(args.over_sampling_batch_size)
+ pendings.update(submit_generate_tasks(state, samples))
+
+ # wait for the generation to finish
+ done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED)
+ for task in done:
+ group: list[Sample] = task.result()
+
+ if do_print:
+ sample = group[0][0] if isinstance(group[0], list) else group[0]
+ logger.info(
+ f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}",
+ )
+ do_print = False
+
+ assert len(group) == args.n_samples_per_prompt
+ all_data.append(group)
+ dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group)
+ if not dynamic_filter_output.keep:
+ metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason)
+ continue
+
+ # add the samples to the data
+ # NOTE: unused samples are not stored back into the data buffer here.
+ if len(data) < target_data_size:
+ data.append(group)
+ pbar.update(args.n_samples_per_prompt)
+
+ pbar.close()
+ sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0]
+ logger.info(
+ f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}",
+ )
+
+ # there are still some unfinished requests, abort them
+ aborted_samples = await abort(state, pendings, rollout_id)
+
+ assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"
+ data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index)
+ all_samples = sorted(
+ all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index
+ )
+
+ # reset the global state to prevent effects on the next rollout or eval.
+ state.reset()
+
+ if f := load_function(args.rollout_sample_filter_path):
+ f(args, data)
+ # Some users may want to process all samples, including those dropped by filters.
+ if f := load_function(args.rollout_all_samples_process_path):
+ f(args, all_samples, data_source)
+
+ return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples
diff --git a/miles/rollout/rm_hub/__init__.py b/miles/rollout/rm_hub/__init__.py
index 62b253dde..e9ee29db4 100644
--- a/miles/rollout/rm_hub/__init__.py
+++ b/miles/rollout/rm_hub/__init__.py
@@ -69,8 +69,18 @@ async def async_rm(args, sample: Sample, **kwargs):
async def batched_async_rm(
args,
samples: list[Sample],
+ inplace_set_reward_field: bool = False,
**kwargs,
-) -> list[int | float]:
+) -> list[int | float] | None:
+ if inplace_set_reward_field:
+ rewards = await batched_async_rm(args, samples, **kwargs)
+ for sample, reward in zip(samples, rewards, strict=True):
+ assert (
+ sample.reward is None
+ ), f"Overriding sample.reward from {sample.reward} to {reward}, is this intended?"
+ sample.reward = reward
+ return None
+
if args.custom_rm_path is not None:
# Ensure the custom reward function is implemented in batch mode
rm_function = load_function(args.custom_rm_path)
diff --git a/miles/router/middleware_hub/radix_tree.py b/miles/router/middleware_hub/radix_tree.py
index 6e722f1e2..67b9d6fe4 100644
--- a/miles/router/middleware_hub/radix_tree.py
+++ b/miles/router/middleware_hub/radix_tree.py
@@ -584,8 +584,8 @@ def retrieve_from_text(self, text: str, return_logprob: bool = True):
text: Input text to get tokens for
return_logprob: If True, also return log probabilities
Returns:
- List of token IDs corresponding to the input text if return_logprob is False.
- Tuple of (token_ids, logp) if return_logprob is True.
+ Tuple of (token_ids, logp, loss_mask) corresponding to the input text.
+ If return_logprob is False, all logp values are 0.0.
"""
# Call find_longest_prefix to get the match result
result = self.find_longest_prefix(text)
diff --git a/miles/router/middleware_hub/radix_tree_middleware.py b/miles/router/middleware_hub/radix_tree_middleware.py
index db57f6456..b9d62d841 100644
--- a/miles/router/middleware_hub/radix_tree_middleware.py
+++ b/miles/router/middleware_hub/radix_tree_middleware.py
@@ -66,12 +66,14 @@ def __init__(self, app, *, router):
self.router.radix_tree = self.radix_tree
async def dispatch(self, request: Request, call_next):
-
path = request.url.path
+ if path == "/generate":
+ return await self._generate(request, call_next)
+ if path == "/retrieve_from_text":
+ return await self._retrieve_from_text(request)
+ return await call_next(request)
- if path != "/generate":
- return await call_next(request)
-
+ async def _generate(self, request: Request, call_next):
request_json = await request.json()
if "text" in request_json:
input_text = request_json.pop("text", "")
@@ -154,6 +156,23 @@ async def dispatch(self, request: Request, call_next):
print(f"[miles-router] Warning: Failed to cache trajectory: {e}")
return response
+ async def _retrieve_from_text(self, request: Request):
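+ # Rebuilds per-token info (ids, rollout logprobs, loss mask) for `text` from the radix
+ # tree; replaces the /retrieve_from_text endpoint previously served by the router itself.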
+ payload = await request.json()
+ text = payload.get("text", "")
+ token_ids, logp, loss_mask = self.radix_tree.retrieve_from_text(text, return_logprob=True)
+ result = {
+ "response": text,
+ "tokens": token_ids,
+ "loss_mask": loss_mask,
+ "rollout_logp": logp,
+ "token_length": len(token_ids),
+ "loss_mask_length": len(loss_mask),
+ }
+ assert (
+ len(token_ids) == len(logp) == len(loss_mask)
+ ), "Token IDs, logp, and loss mask must have the same length"
+ return JSONResponse(result)
+
async def postprocess_sample_with_radix_tree(args, sample: Sample, output: dict):
assert not args.partial_rollout, "Currently partial rollout is not supported when using miles router"
diff --git a/miles/router/router.py b/miles/router/router.py
index 2e8ecfc41..f092f359a 100644
--- a/miles/router/router.py
+++ b/miles/router/router.py
@@ -9,6 +9,7 @@
from fastapi.responses import JSONResponse
from starlette.responses import Response
+from miles.router.session.sessions import setup_session_routes
from miles.utils.misc import load_function
logger = logging.getLogger(__name__)
@@ -64,11 +65,12 @@ def __init__(self, args, verbose=False):
self.app.add_middleware(middleware, router=self)
def _setup_routes(self):
- """Setup all the HTTP routes"""
+ """Setup all the HTTP routes except catch-all proxy"""
# sglang-router api
self.app.post("/add_worker")(self.add_worker)
self.app.get("/list_workers")(self.list_workers)
- self.app.post("/retrieve_from_text")(self.retrieve_from_text)
+ # Session routes - must be registered before catch-all
+ setup_session_routes(self.app, self)
# Catch-all route for proxying to SGLang - must be registered LAST
self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy)
@@ -130,39 +132,51 @@ async def _health_check_loop(self):
async def proxy(self, request: Request, path: str):
"""Proxy all other requests to the SGLang router"""
- # Forward all other paths to SGLang router
+ result = await self._do_proxy(request, path)
+ return self._build_proxy_response(result)
+
+ async def _do_proxy(
+ self,
+ request: Request,
+ path: str,
+ body: bytes | None = None,
+ headers: dict | None = None,
+ ) -> dict:
+ """Core proxy logic. Returns dict with request_body, response_body, status_code, headers."""
worker_url = self._use_url()
url = f"{worker_url}/{path}"
- # Get request body and headers
- body = await request.body()
- headers = dict(request.headers)
+ if body is None:
+ body = await request.body()
+ if headers is None:
+ headers = dict(request.headers)
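+ # Drop stale framing headers: the body may have been rewritten by the caller and httpx
+ # recomputes Content-Length, so forwarding the original values would corrupt the request.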
+ if body is not None:
+ headers = {k: v for k, v in headers.items() if k.lower() not in ("content-length", "transfer-encoding")}
try:
response = await self.client.request(request.method, url, content=body, headers=headers)
- # Eagerly read content so we can return JSON (not streaming)
content = await response.aread()
- content_type = response.headers.get("content-type", "")
- try:
- # Prefer parsing JSON if possible
- data = json.loads(content)
- return JSONResponse(
- content=data,
- status_code=response.status_code,
- headers=dict(response.headers),
- )
- except Exception:
- # Fall back to raw body with original content type
- return Response(
- content=content,
- status_code=response.status_code,
- headers=dict(response.headers),
- media_type=content_type or None,
- )
-
+ return {
+ "request_body": body,
+ "response_body": content,
+ "status_code": response.status_code,
+ "headers": dict(response.headers),
+ }
finally:
self._finish_url(worker_url)
+ def _build_proxy_response(self, result: dict) -> Response:
+ """Build HTTP response from proxy result."""
+ content = result["response_body"]
+ status_code = result["status_code"]
+ headers = result["headers"]
+ content_type = headers.get("content-type", "")
+ try:
+ data = json.loads(content)
+ return JSONResponse(content=data, status_code=status_code, headers=headers)
+ except Exception:
+ return Response(content=content, status_code=status_code, headers=headers, media_type=content_type)
+
async def add_worker(self, request: Request):
"""Add a new worker to the router.
Supports providing the URL via query string or JSON body.
@@ -197,28 +211,6 @@ async def list_workers(self, request: Request):
"""List all registered workers"""
return {"urls": list(self.worker_request_counts.keys())}
- async def retrieve_from_text(self, request: Request):
- """Get token information from text input"""
- body = await request.body()
- payload = json.loads(body) if body else {}
-
- text = payload.get("text", "")
-
- # Use radix tree's retrieve_from_text method (no need to fetch weight version here)
- token_ids, logp, loss_mask = self.radix_tree.retrieve_from_text(text, return_logprob=True)
-
- # Handle the result based on whether logp was requested
- result = {
- "tokens": token_ids, # token IDs
- "response": text, # The input text
- "loss_mask": loss_mask, # Loss mask for the tokens
- "token_length": len(token_ids),
- "loss_mask_length": len(loss_mask),
- "rollout_logp": logp,
- }
-
- return result
-
def _use_url(self):
"""Select worker URL with minimal active requests."""
diff --git a/miles/router/session/seq_trajectory.py b/miles/router/session/seq_trajectory.py
new file mode 100644
index 000000000..26c16fc15
--- /dev/null
+++ b/miles/router/session/seq_trajectory.py
@@ -0,0 +1,263 @@
+import copy
+import logging
+import uuid
+from typing import Any
+
+from pydantic import BaseModel, Field
+from transformers import AutoTokenizer
+
+from miles.rollout.generate_utils.tokenize_utils import tokenize_messages
+from miles.utils.chat_message_utils import calc_last_think_part_index
+
+logger = logging.getLogger(__name__)
+
+
+class TokenInfo(BaseModel):
+ tokens: list[str] = Field(default_factory=list)
+ token_ids: list[int] = Field(default_factory=list)
+ log_probs: list[float] = Field(default_factory=list)
+ loss_mask: list[int] = Field(default_factory=list)
+
+ def remove_tokens(self, start_index: int, end_index: int):
+ # Notice: the end index is exclusive.
+ self.tokens = self.tokens[start_index:end_index]
+ self.token_ids = self.token_ids[start_index:end_index]
+ self.log_probs = self.log_probs[start_index:end_index]
+ self.loss_mask = self.loss_mask[start_index:end_index]
+
+ def insert_tokens(self, tokens: list[str], token_ids: list[int], log_probs: list[float], loss_mask: list[int]):
+ self.tokens.extend(tokens)
+ self.token_ids.extend(token_ids)
+ self.log_probs.extend(log_probs)
+ self.loss_mask.extend(loss_mask)
+
+ def append(self, token: str, token_id: int, log_prob: float, loss_mask: int):
+ self.tokens.append(token)
+ self.token_ids.append(token_id)
+ self.log_probs.append(log_prob)
+ self.loss_mask.append(loss_mask)
+
+ def __add__(self, other: "TokenInfo") -> "TokenInfo":
+ return TokenInfo(
+ tokens=self.tokens + other.tokens,
+ token_ids=self.token_ids + other.token_ids,
+ log_probs=self.log_probs + other.log_probs,
+ loss_mask=self.loss_mask + other.loss_mask,
+ )
+
+ @staticmethod
+ def remove_last_assistant_think_and_handle_truncated_message(
+ token_info: "TokenInfo", model_name: str
+ ) -> "TokenInfo":
+ raise NotImplementedError("Not implemented yet.")
+ tmp = copy.deepcopy(token_info)
+ start, end = calc_last_think_part_index(tmp.token_ids, model_name)
+ if start is None:
+ # No think part found, or the think part is truncated; do not trim.
+ return tmp
+ # Notice: after trimming, the old answer tokens cannot be used to calculate loss, so logp and loss mask are set to 0.
+ if end is not None:
+ tmp.remove_tokens(start, end + 1)
+ if end + 1 < len(token_info.token_ids):
+ n = len(token_info.token_ids)
+ tmp.insert_tokens(
+ token_info.tokens[end + 1 :],
+ token_info.token_ids[end + 1 :],
+ [0.0] * (n - end - 1),
+ [0] * (n - end - 1),
+ )
+ # Handle truncated message.
+
+ return tmp
+
+
+class Turn(BaseModel):
+ """
+ A turn is a group of messages that ends with an assistant message.
+ """
+
+ messages: list[dict[str, Any]]
+ prompt_tokens: TokenInfo
+ response_tokens: TokenInfo
+
+ def __init__(
+ self,
+ messages: list[dict[str, Any]],
+ prompt_tokens: TokenInfo,
+ response_tokens: TokenInfo,
+ ):
+ super().__init__(
+ messages=messages,
+ prompt_tokens=prompt_tokens,
+ response_tokens=response_tokens,
+ )
+ assert (
+ len(messages) > 0 and messages[-1]["role"] == "assistant"
+ ), "The last message must be an assistant message."
+
+ def match_prefix_messages_and_return_remaining(self, other: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
+ """
+ If the messages match with other's prefix, return the remaining messages. Otherwise, return None.
+ """
+ if len(self.messages) < len(other):
+ return None
+ for i in range(len(other)):
+ if self.messages[i] != other[i]:
+ return None
+ return self.messages[len(other) :]
+
+ def handle_token_out_for_next_turn(self, model_name: str) -> TokenInfo:
+ raise NotImplementedError("Not implemented yet.")
+ trimmed_tokens = TokenInfo.remove_last_assistant_think_and_handle_truncated_message(self.prompt_tokens + self.response_tokens, model_name)
+ return trimmed_tokens
+
+
+class SeqTrajectory(BaseModel):
+ """
+ Sequence trajectory state.
+ It only maintains the token info for the last turn.
+ It should not hold any state, which means `token_ids` should always include the final chat-templated text.
+ (Note: if the seq trajectory held state, a crashed request would leave it inconsistent.)
+ """
+
+ num_turns: int = 0
+ model_name: str = ""
+ # History for all turns.
+ turns: list[Turn] = Field(default_factory=list)
+ records: list[dict[str, Any]] = Field(default_factory=list)
+
+ def insert_new_turn(self, turn: Turn):
+ self.turns.append(turn)
+ self.num_turns += 1
+
+ def insert_new_record(self, record: dict[str, Any]):
+ self.records.append(record)
+
+ def match_prefix_turns_and_return_last_turn(
+ self, messages: list[dict[str, Any]], n: int | None = None
+ ) -> tuple[Turn, list[dict[str, Any]]]:
+ if n is None:
+ n = self.num_turns
+ assert n > 0, "n must be greater than 0"
+ remain_messages = messages
+ for i in range(n):
+ turn = self.turns[i]
+ remain_messages = turn.match_prefix_messages_and_return_remaining(remain_messages)
+ if remain_messages is None:
+ raise ValueError(
+ "Under sequence trajectory, messages prefix should match, but unmatched messages: {remain_messages}"
+ )
+ return self.turns[n - 1], remain_messages
+
+ def calc_prompt_tokens_info(
+ self,
+ messages: list[dict[str, Any]],
+ tokenizer: AutoTokenizer,
+ cross_turn_token_out: bool = False,
+ inherit_last_assistant: bool = False,
+ ) -> TokenInfo:
+ if cross_turn_token_out and self.num_turns > 0:
+ if inherit_last_assistant:
+ raise NotImplementedError("Not implemented yet.")
+ turn, remain_messages = self.match_prefix_turns_and_return_last_turn(messages)
+ token_info = turn.handle_token_out_for_next_turn(self.model_name)
+ else:
+ if self.num_turns >= 2:
+ turn, remain_messages = self.match_prefix_turns_and_return_last_turn(messages, self.num_turns - 1)
+ old_token_ids = turn.prompt_tokens.token_ids + turn.response_tokens.token_ids
+ else:
+ remain_messages = messages
+ old_token_ids = []
+ new_token_ids = tokenize_messages(remain_messages, tokenizer, add_generation_prompt=True)
+ token_ids = old_token_ids + new_token_ids
+ log_probs = [0.0] * len(token_ids)
+ loss_mask = [0] * len(token_ids)
+ token_info = TokenInfo(
+ tokens=tokenizer.convert_ids_to_tokens(token_ids),
+ token_ids=token_ids,
+ log_probs=log_probs,
+ loss_mask=loss_mask,
+ )
+ else:
+ token_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
+ log_probs = [0.0] * len(token_ids)
+ loss_mask = [0] * len(token_ids)
+ token_info = TokenInfo(
+ tokens=tokenizer.convert_ids_to_tokens(token_ids),
+ token_ids=token_ids,
+ log_probs=log_probs,
+ loss_mask=loss_mask,
+ )
+
+ return token_info
+
+ def get_last_turn_token_info(self) -> TokenInfo:
+ if not self.turns:
+ return TokenInfo()
+ return self.turns[-1].prompt_tokens + self.turns[-1].response_tokens
+
+
+class SeqTrajectoryManager:
+ def __init__(self, args, tokenizer: AutoTokenizer):
+ self.sessions: dict[str, SeqTrajectory] = {}
+ self.args = args
+ self.tokenizer = tokenizer
+
+ def create_session(self) -> str:
+ session_id = uuid.uuid4().hex
+ self.sessions[session_id] = SeqTrajectory()
+ return session_id
+
+ def get_session_by_id(self, session_id: str) -> TokenInfo | None:
+ session = self.sessions.get(session_id)
+ if session is None:
+ return None
+ return session.get_last_turn_token_info()
+
+ def get_session_records(self, session_id: str) -> list[dict[str, Any]] | None:
+ session = self.sessions.get(session_id)
+ if session is None:
+ return None
+ return session.records
+
+ def calc_prompt_tokens(self, session_id: str, messages: list[dict[str, Any]]) -> TokenInfo | None:
+ # Notice: the sequence trajectory manager only supports input messages whose prefix matches the stored history.
+ session = self.sessions.get(session_id)
+ if session is None:
+ return None
+ cross_turn_token_out = getattr(self.args, "cross_turn_token_out", False)
+ inherit_last_assistant = getattr(self.args, "inherit_last_assistant", False)
+ token_info: TokenInfo = session.calc_prompt_tokens_info(
+ messages,
+ self.tokenizer,
+ cross_turn_token_out=cross_turn_token_out,
+ inherit_last_assistant=inherit_last_assistant,
+ )
+ return token_info
+ # if remain_messages is None:
+ # TODO(jiajun): Should we truncate the think part of the last turn's assistant message if the new turn does not include any new message?
+ # Turn 1: sys | user | assistant | tool | assistant
+ # Turn 2: sys | user | assistant | tool | assistant | ???
+ # Normal: sys | user | assistant | tool | assistant | ???
+ # Not hard to fix, but temporarily leave this TODO.
+ # raise ValueError("Currently, we do not support consecutive assistant message input.")
+
+ def delete_session_by_id(self, session_id: str) -> bool:
+ session = self.sessions.pop(session_id, None)
+ if session is None:
+ return False
+ return True
+
+ def add_record(self, session_id: str, turn: Turn) -> bool:
+ session = self.sessions.get(session_id)
+ if session is None:
+ raise ValueError(f"Session {session_id} not found.")
+ session.insert_new_turn(turn)
+ return True
+
+ def add_session_record(self, session_id: str, record: dict[str, Any]) -> bool:
+ session = self.sessions.get(session_id)
+ if session is None:
+ raise ValueError(f"Session {session_id} not found.")
+ session.insert_new_record(record)
+ return True
diff --git a/miles/router/session/sessions.py b/miles/router/session/sessions.py
new file mode 100644
index 000000000..6573127a3
--- /dev/null
+++ b/miles/router/session/sessions.py
@@ -0,0 +1,115 @@
+import json
+import time
+from typing import TYPE_CHECKING
+
+from fastapi import Request
+from fastapi.responses import JSONResponse, Response
+from pydantic import BaseModel
+from transformers import AutoTokenizer
+
+from miles.router.session.seq_trajectory import SeqTrajectoryManager, TokenInfo, Turn
+
+if TYPE_CHECKING:
+ from miles.router.router import MilesRouter
+
+
+class SessionRecord(BaseModel):
+ timestamp: float
+ method: str
+ path: str
+ request: dict
+ response: dict
+ status_code: int
+
+
+class GetSessionResponse(BaseModel):
+ session_id: str
+ records: dict
+ session_records: list[SessionRecord] | None = None
+
+
+def setup_session_routes(app, router: "MilesRouter"):
+
+ # TODO temporary hack before @guapisolo implements TITO
+ # ============================= HACK START ===============================
+ # Tokenizer is loaded eagerly here; TODO: lazy-load it for tests that don't provide hf_checkpoint
+ tokenizer = AutoTokenizer.from_pretrained(router.args.hf_checkpoint, trust_remote_code=True)
+ manager = SeqTrajectoryManager(router.args, tokenizer)
+
+ # ============================= HACK END ===============================
+
+ @app.post("/sessions")
+ async def create_session():
+ session_id = manager.create_session()
+ return {"session_id": session_id}
+
+ @app.get("/sessions/{session_id}")
+ async def get_session(session_id: str):
+ token_info = manager.get_session_by_id(session_id)
+ if token_info is None:
+ return JSONResponse(status_code=404, content={"error": "session not found"})
+ session_records = manager.get_session_records(session_id)
+ return GetSessionResponse(
+ session_id=session_id,
+ records=token_info.model_dump(),
+ session_records=session_records,
+ )
+
+ @app.delete("/sessions/{session_id}")
+ async def delete_session(session_id: str):
+ status = manager.delete_session_by_id(session_id)
+ if not status:
+ return JSONResponse(status_code=404, content={"error": "session not found"})
+ return Response(status_code=204)
+
+ @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
+ async def session_proxy(request: Request, session_id: str, path: str):
+ body = await request.body()
+ request_body = json.loads(body) if body else {}
+
+ prompt_token_info = TokenInfo()
+ response_token_info = TokenInfo()
+ if "messages" in request_body and "input_ids" not in request_body:
+ prompt_token_info = manager.calc_prompt_tokens(session_id, request_body["messages"])
+ if prompt_token_info is None:
+ return JSONResponse(status_code=404, content={"error": "session not found"})
+ token_ids = prompt_token_info.token_ids
+ request_body["input_ids"] = token_ids
+ body = json.dumps(request_body).encode("utf-8")
+
+ result = await router._do_proxy(request, path, body=body)
+
+ response = json.loads(result["response_body"])
+
+ choice = response.get("choices", [{}])[0]
+ messages = request_body["messages"] + [choice["message"]]
+
+ assert "logprobs" in choice and "content" in choice["logprobs"], "logprobs must be in choice"
+ logprobs_content = choice["logprobs"]["content"]
+
+ for item in logprobs_content:
+ if "token" in item and "token_id" not in item:
+ item["token_id"] = tokenizer.convert_tokens_to_ids(item["token"])
+ response_token_info.append(item["token"], item["token_id"], item["logprob"], 1)
+
+ manager.add_record(
+ session_id,
+ Turn(
+ messages=messages,
+ prompt_tokens=prompt_token_info,
+ response_tokens=response_token_info,
+ ),
+ )
+ manager.add_session_record(
+ session_id,
+ SessionRecord(
+ timestamp=time.time(),
+ method=request.method,
+ path=path,
+ request=request_body,
+ response=response,
+ status_code=result["status_code"],
+ ).model_dump(),
+ )
+
+ return router._build_proxy_response(result)
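+
+
+# Illustrative session flow against the routes above:
+#   1. POST   /sessions                                   -> {"session_id": ...}
+#   2. POST   /sessions/{session_id}/v1/chat/completions  with {"messages": [...]}
+#   3. GET    /sessions/{session_id}                      -> accumulated token info and records
+#   4. DELETE /sessions/{session_id}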
diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py
index 79b2c419c..071020292 100644
--- a/miles/utils/arguments.py
+++ b/miles/utils/arguments.py
@@ -10,8 +10,10 @@
from miles.backends.sglang_utils.arguments import add_sglang_arguments
from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args
+from miles.utils.environ import enable_experimental_rollout_refactor
from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list
from miles.utils.logging_utils import configure_logger
+from miles.utils.misc import load_function
logger = logging.getLogger(__name__)
@@ -204,7 +206,11 @@ def add_rollout_arguments(parser):
parser.add_argument(
"--rollout-function-path",
type=str,
- default="miles.rollout.sglang_rollout.generate_rollout",
+ default=(
+ "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn"
+ if enable_experimental_rollout_refactor()
+ else "miles.rollout.sglang_rollout.generate_rollout"
+ ),
help=(
"Path to the rollout generation function."
"You should use this model to create your own custom rollout function, "
@@ -1344,6 +1350,20 @@ def add_ci_arguments(parser):
)
return parser
+ def add_user_provided_function_arguments(parser):
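+ # Let the functions referenced by --rollout-function-path / --custom-generate-function-path
+ # register their own CLI arguments; paths that cannot be imported are skipped.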
+ args_partial, _ = parser.parse_known_args()
+ for path in [
+ args_partial.rollout_function_path,
+ args_partial.custom_generate_function_path,
+ ]:
+ try:
+ fn = load_function(path)
+ except (ModuleNotFoundError, ValueError):
+ continue
+ if fn is not None and callable(getattr(fn, "add_arguments", None)):
+ fn.add_arguments(parser)
+ return parser
+
def add_sglang_tp_size():
temp_parser = argparse.ArgumentParser(add_help=False)
temp_parser.add_argument("--rollout-num-gpus-per-engine", type=int, default=1)
@@ -1374,6 +1394,8 @@ def add_sglang_tp_size():
parser = add_prefill_decode_disaggregation_arguments(parser)
parser = add_ci_arguments(parser)
parser = add_custom_megatron_plugins_arguments(parser)
+ if enable_experimental_rollout_refactor():
+ parser = add_user_provided_function_arguments(parser)
reset_arg(
parser,
"--custom-config-path",
diff --git a/miles/utils/chat_message_utils.py b/miles/utils/chat_message_utils.py
new file mode 100644
index 000000000..815b1dca4
--- /dev/null
+++ b/miles/utils/chat_message_utils.py
@@ -0,0 +1,39 @@
+# These are helper functions for think token lookup.
+THINK_TOKEN_START = {
+ "qwen3": ("", 151667),
+}
+THINK_TOKEN_END = {
+ "qwen3": ("", 151668),
+}
+
+
+def get_think_token_start(model_name: str) -> tuple[str, int]:
+ return THINK_TOKEN_START[model_name]
+
+
+def get_think_token_end(model_name: str) -> tuple[str, int]:
+ return THINK_TOKEN_END[model_name]
+
+
+def calc_last_think_part_index(tokens: list[int], model_name: str) -> tuple[int | None, int | None]:
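+ # Returns (start_index, end_index) of the last think span: the index of the last
+ # <think> token and of the last </think> token after it; end_index is None if the
+ # think part was truncated.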
+ start_index = None
+ end_index = None
+ for i in range(len(tokens)):
+ if tokens[i] == get_think_token_start(model_name)[1]:
+ start_index = i
+
+ if start_index is None:
+ # No think tokens found, no strip.
+ return None, None
+
+ for i in range(start_index + 1, len(tokens)):
+ if tokens[i] == get_think_token_end(model_name)[1]:
+ end_index = i
+
+ # If the think part is truncated, end_index will be None.
+ return start_index, end_index
+
+
+def check_is_truncated_message(tokens: list[int], model_name: str) -> bool:
+ # TODO: handle this later.
+ pass
diff --git a/miles/utils/environ.py b/miles/utils/environ.py
new file mode 100644
index 000000000..35d1f350e
--- /dev/null
+++ b/miles/utils/environ.py
@@ -0,0 +1,14 @@
+import os
+
+_printed_experimental_rollout_refactor = False
+
+
+def enable_experimental_rollout_refactor() -> bool:
+ result = bool(int(os.environ.get("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", "0")))
+
+ global _printed_experimental_rollout_refactor
+ if result and not _printed_experimental_rollout_refactor:
+ print("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 is enabled (experimental feature)")
+ _printed_experimental_rollout_refactor = True
+
+ return result
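+
+
+# Illustrative usage: export MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 before launching so that
+# --rollout-function-path defaults to the experimental InferenceRolloutFn (see arguments.py).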
diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py
index 2b3e6e192..0abdbbf59 100644
--- a/miles/utils/http_utils.py
+++ b/miles/utils/http_utils.py
@@ -162,11 +162,15 @@ def _next_actor():
return actor
-async def _post(client, url, payload, max_retries=60):
+async def _post(client, url, payload, max_retries=60, action="post"):
retry_count = 0
while retry_count < max_retries:
try:
- response = await client.post(url, json=payload or {})
+ if action in ("delete", "get"):
+ assert not payload
+ response = await getattr(client, action)(url)
+ else:
+ response = await getattr(client, action)(url, json=payload or {})
response.raise_for_status()
try:
output = response.json()
@@ -240,8 +244,8 @@ def __init__(self, concurrency: int):
timeout=httpx.Timeout(None),
)
- async def do_post(self, url, payload, max_retries=60):
- return await _post(self._client, url, payload, max_retries)
+ async def do_post(self, url, payload, max_retries=60, action="post"):
+ return await _post(self._client, url, payload, max_retries, action=action)
# Create actors per node
created = []
@@ -265,7 +269,8 @@ async def do_post(self, url, payload, max_retries=60):
_post_actors = created
-async def post(url, payload, max_retries=60):
+# TODO: may generalize the name since it now also handles HTTP DELETE/GET etc. (with retries and remote execution)
+async def post(url, payload, max_retries=60, action="post"):
# If distributed mode is enabled and actors exist, dispatch via Ray.
if _distributed_post_enabled and _post_actors:
try:
@@ -274,15 +279,16 @@ async def post(url, payload, max_retries=60):
actor = _next_actor()
if actor is not None:
# Use a thread to avoid blocking the event loop on ray.get
- obj_ref = actor.do_post.remote(url, payload, max_retries)
+ obj_ref = actor.do_post.remote(url, payload, max_retries, action=action)
return await asyncio.to_thread(ray.get, obj_ref)
except Exception as e:
logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})")
# fall through to local
- return await _post(_http_client, url, payload, max_retries)
+ return await _post(_http_client, url, payload, max_retries, action=action)
+# TODO: unify with `post` so GET also gains retries and remote execution
async def get(url):
response = await _http_client.get(url)
response.raise_for_status()
diff --git a/miles/utils/misc.py b/miles/utils/misc.py
index c0a96d636..bae72ec0d 100644
--- a/miles/utils/misc.py
+++ b/miles/utils/misc.py
@@ -1,17 +1,55 @@
+import asyncio
import importlib
import subprocess
+from contextlib import contextmanager
import ray
from miles.utils.http_utils import is_port_available
+# Mainly used for tests, where `load_function` needs to resolve functions generated on the fly
+class FunctionRegistry:
+ def __init__(self):
+ self._registry: dict[str, object] = {}
+
+ @contextmanager
+ def temporary(self, name: str, fn: object):
+ self._register(name, fn)
+ try:
+ yield
+ finally:
+ self._unregister(name)
+
+ def get(self, name: str) -> object | None:
+ return self._registry.get(name)
+
+ def _register(self, name: str, fn: object) -> None:
+ assert name not in self._registry
+ self._registry[name] = fn
+
+ def _unregister(self, name: str) -> None:
+ assert name in self._registry
+ self._registry.pop(name)
+
+
+function_registry = FunctionRegistry()
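+# Illustrative usage (e.g. in a test; the dotted name is arbitrary):
+#   with function_registry.temporary("my.fake.reward_fn", lambda *a, **k: 1.0):
+#       assert load_function("my.fake.reward_fn") is not None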
+
+
+# TODO may rename to `load_object` since it can be used to load things like tool_specs
def load_function(path):
"""
- Load a function from a module.
+ Load a function from registry or module.
:param path: The path to the function, e.g. "module.submodule.function".
:return: The function object.
"""
+ if path is None:
+ return None
+
+ registered = function_registry.get(path)
+ if registered is not None:
+ return registered
+
module_path, _, attr = path.rpartition(".")
module = importlib.import_module(module_path)
return getattr(module, attr)
@@ -30,8 +68,9 @@ def __call__(cls, *args, **kwargs):
cls._instances[cls] = instance
return cls._instances[cls]
- def clear_instances(cls):
- cls._instances = {}
+ @staticmethod
+ def clear_all_instances():
+ SingletonMeta._instances.clear()
def exec_command(cmd: str, capture_output: bool = False) -> str | None:
@@ -92,3 +131,8 @@ def should_run_periodic_action(
step = rollout_id + 1
return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0)
+
+
+async def as_completed_async(tasks):
+ for coro in asyncio.as_completed(tasks):
+ yield await coro
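+
+
+# Illustrative usage:
+#   async for result in as_completed_async(pending_tasks):
+#       handle(result)  # `handle` / `pending_tasks` are placeholders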
diff --git a/miles/utils/test_utils/__init__.py b/miles/utils/test_utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py
new file mode 100644
index 000000000..7602230d6
--- /dev/null
+++ b/miles/utils/test_utils/mock_sglang_server.py
@@ -0,0 +1,250 @@
+import asyncio
+import re
+import time
+import uuid
+from collections.abc import Callable
+from contextlib import contextmanager
+from dataclasses import asdict, dataclass
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+from pydantic import TypeAdapter
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.function_call_parser import FunctionCallParser
+from transformers import AutoTokenizer
+
+from miles.utils.http_utils import find_available_port
+from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer
+
+
+@dataclass(frozen=True)
+class ProcessResultMetaInfo:
+ weight_version: str | None = None
+ routed_experts: str | None = None
+ spec_accept_token_num: int | None = None
+ spec_draft_token_num: int | None = None
+ spec_verify_ct: int | None = None
+
+ def to_dict(self) -> dict:
+ return {k: v for k, v in asdict(self).items() if v is not None}
+
+
+@dataclass(frozen=True)
+class ProcessResult:
+ text: str
+ finish_reason: str = "stop"
+ cached_tokens: int = 0
+ meta_info: ProcessResultMetaInfo = ProcessResultMetaInfo()
+
+
+ProcessFn = Callable[[str], ProcessResult]
+
+
+class MockSGLangServer:
+ def __init__(
+ self,
+ model_name: str,
+ process_fn: ProcessFn,
+ host: str,
+ port: int,
+ latency: float = 0.0,
+ ):
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ self.process_fn = process_fn
+ self.host = host
+ self.port = port or find_available_port(30000)
+ self.latency = latency
+
+ self.app = FastAPI()
+ self._server: UvicornThreadServer | None = None
+
+ self.request_log: list[dict] = []
+ self._concurrency = Counter()
+
+ self._setup_routes()
+
+ @property
+ def max_concurrent(self) -> int:
+ return self._concurrency.max_value
+
+ def reset_stats(self):
+ self.request_log.clear()
+ self._concurrency.reset()
+
+ def start(self):
+ self._server = UvicornThreadServer(self.app, host=self.host, port=self.port)
+ self._server.start()
+
+ def stop(self):
+ if self._server is not None:
+ self._server.stop()
+
+ @property
+ def url(self) -> str:
+ return f"http://{self.host}:{self.port}"
+
+ def _setup_routes(self):
+ @self.app.post("/generate")
+ async def generate(request: Request):
+ return await self._handle_generate_like_request(request, self._compute_generate_response)
+
+ @self.app.post("/v1/chat/completions")
+ async def chat_completions(request: Request):
+ return await self._handle_generate_like_request(request, self._compute_chat_completions_response)
+
+ @self.app.get("/health")
+ async def health():
+ return JSONResponse(content={"status": "ok"})
+
+ @self.app.post("/abort_request")
+ async def abort_request(_request: Request):
+ return JSONResponse(content={"status": "ok"})
+
+ async def _handle_generate_like_request(self, request: Request, compute_fn: Callable[[dict], dict]):
+ payload = await request.json()
+ self.request_log.append(payload)
+ with self._concurrency.track():
+ if self.latency > 0:
+ await asyncio.sleep(self.latency)
+ response = compute_fn(payload)
+ return JSONResponse(content=response)
+
+ def _compute_generate_response(self, payload: dict) -> dict:
+ assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True"
+ input_ids = payload.get("input_ids", [])
+
+ prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False)
+ process_result = self.process_fn(prompt_str)
+ output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False)
+
+ prompt_tokens = len(input_ids)
+ completion_tokens = len(output_ids)
+
+ finish_reason_dict = {"type": process_result.finish_reason}
+ if process_result.finish_reason == "length":
+ finish_reason_dict["length"] = completion_tokens
+
+ output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)]
+
+ meta_info = {
+ "finish_reason": finish_reason_dict,
+ "prompt_tokens": prompt_tokens,
+ "cached_tokens": process_result.cached_tokens,
+ "completion_tokens": completion_tokens,
+ "output_token_logprobs": output_token_logprobs,
+ **process_result.meta_info.to_dict(),
+ }
+
+ return {"text": process_result.text, "meta_info": meta_info}
+
+ def _compute_chat_completions_response(self, payload: dict) -> dict:
+ messages = payload.get("messages", [])
+ tools = payload.get("tools")
+
+ prompt_str = self.tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True, tools=tools
+ )
+
+ print(f"_messages: {messages=}", flush=True)
+ print(f"_compute_chat_completions_response: {prompt_str=}", flush=True)
+ process_result = self.process_fn(prompt_str)
+ output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False)
+
+ logprobs_content = [
+ {"token": self.tokenizer.convert_ids_to_tokens(tid), "logprob": -1 / 128 * i}
+ for i, tid in enumerate(output_ids)
+ ]
+
+ finish_reason = process_result.finish_reason
+ tool_calls = None
+ if tools and finish_reason == "stop":
+ parser = FunctionCallParser(
+ tools=TypeAdapter(list[Tool]).validate_python(tools),
+ tool_call_parser="qwen25",
+ )
+ message_content, parsed_calls = parser.parse_non_stream(process_result.text)
+ if parsed_calls:
+ finish_reason = "tool_calls"
+ tool_calls = [
+ {
+ "id": f"call{i:05d}",
+ "type": "function",
+ "function": {"name": call.name, "arguments": call.parameters or "{}"},
+ }
+ for i, call in enumerate(parsed_calls)
+ ]
+ else:
+ message_content = process_result.text
+
+ return {
+ "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": "mock-model",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": message_content,
+ "tool_calls": tool_calls,
+ },
+ "logprobs": {"content": logprobs_content},
+ "finish_reason": finish_reason,
+ }
+ ],
+ }
+
+
+class Counter:
+ def __init__(self):
+ self._current = 0
+ self._max = 0
+
+ @property
+ def max_value(self) -> int:
+ return self._max
+
+ def reset(self):
+ self._current = 0
+ self._max = 0
+
+ @contextmanager
+ def track(self):
+ self._current += 1
+ self._max = max(self._max, self._current)
+ try:
+ yield
+ finally:
+ self._current -= 1
+
+
+def default_process_fn(prompt: str) -> ProcessResult:
+ match = re.search(r"What is 1\+(\d+)\?", prompt)
+ if match:
+ num = int(match.group(1))
+ ans = 1 + num
+ return ProcessResult(text=f"\\boxed{{{ans}}}", finish_reason="stop")
+ return ProcessResult(text="I don't understand.", finish_reason="stop")
+
+
+@contextmanager
+def with_mock_server(
+ model_name: str = "Qwen/Qwen3-0.6B",
+ process_fn: ProcessFn = default_process_fn,
+ host: str = "127.0.0.1",
+ port: int | None = None,
+ latency: float = 0.0,
+):
+ server = MockSGLangServer(
+ model_name=model_name,
+ process_fn=process_fn,
+ host=host,
+ port=port,
+ latency=latency,
+ )
+ try:
+ server.start()
+ yield server
+ finally:
+ server.stop()
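+
+
+# Illustrative usage:
+#   with with_mock_server(process_fn=default_process_fn, latency=0.05) as server:
+#       resp = requests.post(f"{server.url}/generate", json={"input_ids": [...], "return_logprob": True})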
diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py
new file mode 100644
index 000000000..26c18738e
--- /dev/null
+++ b/miles/utils/test_utils/mock_tools.py
@@ -0,0 +1,367 @@
+import json
+import logging
+
+from transformers import AutoTokenizer
+
+from miles.utils.test_utils.mock_sglang_server import ProcessResult
+
+logger = logging.getLogger(__name__)
+SAMPLE_TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_year",
+ "description": "Get current year",
+ "parameters": {
+ "type": "object",
+ "properties": {},
+ "required": [],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_temperature",
+ "description": "Get temperature for a location",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ "required": ["location"],
+ },
+ },
+ },
+]
+
+
+def _get_year(params: dict) -> str:
+ assert len(params) == 0
+ return json.dumps({"year": 2026})
+
+
+def _get_temperature(params: dict) -> str:
+ temps = {"Mars": -60, "Earth": 15}
+ location = params.get("location")
+ assert location in temps, f"Unknown location: {location}"
+ return json.dumps({"temperature": temps[location]})
+
+
+TOOL_EXECUTORS = {
+ "get_year": _get_year,
+ "get_temperature": _get_temperature,
+}
+
+
+async def execute_tool_call(name: str, params: dict) -> str:
+ return TOOL_EXECUTORS[name](params)
+
+
+_SYSTEM_PROMPT = (
+ "<|im_start|>system\n"
+ "# Tools\n"
+ "\n"
+ "You may call one or more functions to assist with the user query.\n"
+ "\n"
+ "You are provided with function signatures within XML tags:\n"
+ "\n"
+ '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n'
+ '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n'
+ "\n"
+ "\n"
+ "For each function call, return a json object with function name and arguments within XML tags:\n"
+ "\n"
+ '{"name": , "arguments": }\n'
+ "<|im_end|>\n"
+)
+
+
+_TOKENIZER = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
+
+
+class TwoTurnStub:
+ """Stub for 2-turn: get_year + get_temperature(Mars) -> final answer"""
+
+ USER_QUESTION = "What is 42 + year + temperature?"
+
+ FIRST_RESPONSE = (
+ "Let me get the year and temperature first.\n"
+ "\n"
+ '{"name": "get_year", "arguments": {}}\n'
+ "\n"
+ "\n"
+ '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n'
+ "<|im_end|>\n"
+ )
+
+ FIRST_TOOL_RESPONSE = (
+ "<|im_start|>user\n"
+ "\n"
+ '{"year": 2026}\n'
+ "\n"
+ "\n"
+ '{"temperature": -60}\n'
+ "<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+
+ SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008."
+
+ FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n"
+ SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE
+
+ PROMPT = [{"role": "user", "content": USER_QUESTION}]
+
+ FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"]
+ SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"]
+
+ FIRST_RESPONSE_CONTENT = "Let me get the year and temperature first."
+ FIRST_TOOL_CALLS_OPENAI_FORMAT = [
+ {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"},
+ {
+ "id": "call00001",
+ "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"},
+ "type": "function",
+ },
+ ]
+
+ OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}]
+
+ OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [
+ {
+ "content": FIRST_RESPONSE_CONTENT,
+ "refusal": None,
+ "role": "assistant",
+ "annotations": None,
+ "audio": None,
+ "function_call": None,
+ "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT,
+ },
+ {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"},
+ {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"},
+ ]
+
+ @staticmethod
+ def process_fn(prompt: str) -> ProcessResult:
+ prompt_response_pairs = {
+ TwoTurnStub.FIRST_PROMPT: TwoTurnStub.FIRST_RESPONSE,
+ TwoTurnStub.SECOND_PROMPT: TwoTurnStub.SECOND_RESPONSE,
+ }
+
+ for expect_prompt, response in prompt_response_pairs.items():
+ if prompt == expect_prompt:
+ return ProcessResult(text=response, finish_reason="stop")
+
+ raise ValueError(f"Unexpected {prompt=}")
+
+
+class ThreeTurnStub:
+ """Stub for 3-turn: get_year + get_temperature(Mars) -> get_temperature(Earth) -> final answer"""
+
+ USER_QUESTION = "What is 42 + year + Mars temperature + Earth temperature?"
+
+ FIRST_RESPONSE = (
+ "Let me get the year and Mars temperature first.\n"
+ "\n"
+ '{"name": "get_year", "arguments": {}}\n'
+ "\n"
+ "\n"
+ '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n'
+ "<|im_end|>\n"
+ )
+
+ SECOND_RESPONSE = (
+ "Now let me get Earth temperature.\n"
+ "\n"
+ '{"name": "get_temperature", "arguments": {"location": "Earth"}}\n'
+ "<|im_end|>\n"
+ )
+
+ FIRST_TOOL_RESPONSE = (
+ "<|im_start|>user\n"
+ "\n"
+ '{"year": 2026}\n'
+ "\n"
+ "\n"
+ '{"temperature": -60}\n'
+ "<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+
+ SECOND_TOOL_RESPONSE = (
+ "<|im_start|>user\n"
+ "\n"
+ '{"temperature": 15}\n'
+ "<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+
+ THIRD_RESPONSE = "The answer is: 42 + 2026 + -60 + 15 = 2023."
+
+ FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n"
+ SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE
+ THIRD_PROMPT = SECOND_PROMPT + SECOND_RESPONSE + SECOND_TOOL_RESPONSE
+
+ PROMPT = [{"role": "user", "content": USER_QUESTION}]
+
+ FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"]
+ SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"]
+ THIRD_PROMPT_TOKEN_IDS = _TOKENIZER(THIRD_PROMPT, add_special_tokens=False)["input_ids"]
+
+ FIRST_RESPONSE_CONTENT = "Let me get the year and Mars temperature first."
+ FIRST_TOOL_CALLS_OPENAI_FORMAT = [
+ {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"},
+ {
+ "id": "call00001",
+ "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"},
+ "type": "function",
+ },
+ ]
+
+ SECOND_RESPONSE_CONTENT = "Now let me get Earth temperature."
+ SECOND_TOOL_CALLS_OPENAI_FORMAT = [
+ {
+ "id": "call00000",
+ "function": {"arguments": '{"location": "Earth"}', "name": "get_temperature"},
+ "type": "function",
+ },
+ ]
+
+ OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}]
+
+ OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [
+ {
+ "content": FIRST_RESPONSE_CONTENT,
+ "refusal": None,
+ "role": "assistant",
+ "annotations": None,
+ "audio": None,
+ "function_call": None,
+ "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT,
+ },
+ {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"},
+ {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"},
+ ]
+
+ OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT = OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT + [
+ {
+ "content": SECOND_RESPONSE_CONTENT,
+ "refusal": None,
+ "role": "assistant",
+ "annotations": None,
+ "audio": None,
+ "function_call": None,
+ "tool_calls": SECOND_TOOL_CALLS_OPENAI_FORMAT,
+ },
+ {"role": "tool", "tool_call_id": "call00000", "content": '{"temperature": 15}', "name": "get_temperature"},
+ ]
+
+ @staticmethod
+ def process_fn(prompt: str) -> ProcessResult:
+ prompt_response_pairs = {
+ ThreeTurnStub.FIRST_PROMPT: ThreeTurnStub.FIRST_RESPONSE,
+ ThreeTurnStub.SECOND_PROMPT: ThreeTurnStub.SECOND_RESPONSE,
+ ThreeTurnStub.THIRD_PROMPT: ThreeTurnStub.THIRD_RESPONSE,
+ }
+
+ for expect_prompt, response in prompt_response_pairs.items():
+ if prompt == expect_prompt:
+ return ProcessResult(text=response, finish_reason="stop")
+
+ raise ValueError(f"Unexpected {prompt=}")
+
+
+class ThinkingThreeTurnStub:
+ """3-turn stub with a think tag in the assistant response."""
+
+ USER_QUESTION = ThreeTurnStub.USER_QUESTION
+ THINK_PREFIX = "<think>\nLet me think.\n</think>\n\n"
+ FOURTH_USER_MESSAGE = "Thanks."
+
+ FIRST_RESPONSE = THINK_PREFIX + ThreeTurnStub.FIRST_RESPONSE
+ SECOND_RESPONSE = ThreeTurnStub.SECOND_RESPONSE
+ THIRD_RESPONSE = ThreeTurnStub.THIRD_RESPONSE
+ FOURTH_RESPONSE = "You're welcome."
+
+ FIRST_TOOL_RESPONSE = ThreeTurnStub.FIRST_TOOL_RESPONSE
+ SECOND_TOOL_RESPONSE = ThreeTurnStub.SECOND_TOOL_RESPONSE
+
+ FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n"
+ SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE
+ THIRD_PROMPT = SECOND_PROMPT + SECOND_RESPONSE + SECOND_TOOL_RESPONSE
+ FOURTH_PROMPT = (
+ THIRD_PROMPT
+ + THIRD_RESPONSE
+ + "<|im_end|>\n"
+ + "<|im_start|>user\n"
+ + FOURTH_USER_MESSAGE
+ + "<|im_end|>\n"
+ + "<|im_start|>assistant\n"
+ )
+
+ PROMPT = [{"role": "user", "content": USER_QUESTION}]
+
+ FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"]
+ SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"]
+ THIRD_PROMPT_TOKEN_IDS = _TOKENIZER(THIRD_PROMPT, add_special_tokens=False)["input_ids"]
+ FOURTH_PROMPT_TOKEN_IDS = _TOKENIZER(FOURTH_PROMPT, add_special_tokens=False)["input_ids"]
+
+ FIRST_RESPONSE_CONTENT = THINK_PREFIX + ThreeTurnStub.FIRST_RESPONSE_CONTENT
+ FIRST_TOOL_CALLS_OPENAI_FORMAT = ThreeTurnStub.FIRST_TOOL_CALLS_OPENAI_FORMAT
+ SECOND_RESPONSE_CONTENT = ThreeTurnStub.SECOND_RESPONSE_CONTENT
+ SECOND_TOOL_CALLS_OPENAI_FORMAT = ThreeTurnStub.SECOND_TOOL_CALLS_OPENAI_FORMAT
+
+ OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}]
+
+ OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [
+ {
+ "content": FIRST_RESPONSE_CONTENT,
+ "refusal": None,
+ "role": "assistant",
+ "annotations": None,
+ "audio": None,
+ "function_call": None,
+ "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT,
+ },
+ {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"},
+ {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"},
+ ]
+
+ OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT = OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT + [
+ {
+ "content": SECOND_RESPONSE_CONTENT,
+ "refusal": None,
+ "role": "assistant",
+ "annotations": None,
+ "audio": None,
+ "function_call": None,
+ "tool_calls": SECOND_TOOL_CALLS_OPENAI_FORMAT,
+ },
+ {"role": "tool", "tool_call_id": "call00000", "content": '{"temperature": 15}', "name": "get_temperature"},
+ ]
+ OPENAI_MESSAGES_FOURTH_TURN_FROM_CLIENT = OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT + [
+ {
+ "content": THIRD_RESPONSE,
+ "refusal": None,
+ "role": "assistant",
+ "annotations": None,
+ "audio": None,
+ "function_call": None,
+ "tool_calls": None,
+ },
+ {"role": "user", "content": FOURTH_USER_MESSAGE},
+ ]
+
+ @staticmethod
+ def process_fn(prompt: str) -> ProcessResult:
+ prompt_response_pairs = {
+ ThinkingThreeTurnStub.FIRST_PROMPT: ThinkingThreeTurnStub.FIRST_RESPONSE,
+ ThinkingThreeTurnStub.SECOND_PROMPT: ThinkingThreeTurnStub.SECOND_RESPONSE,
+ ThinkingThreeTurnStub.THIRD_PROMPT: ThinkingThreeTurnStub.THIRD_RESPONSE,
+ ThinkingThreeTurnStub.FOURTH_PROMPT: ThinkingThreeTurnStub.FOURTH_RESPONSE,
+ }
+
+ for expect_prompt, response in prompt_response_pairs.items():
+ if prompt == expect_prompt:
+ return ProcessResult(text=response, finish_reason="stop")
+
+ raise ValueError(f"Unexpected {prompt=}")
diff --git a/miles/utils/test_utils/uvicorn_thread_server.py b/miles/utils/test_utils/uvicorn_thread_server.py
new file mode 100644
index 000000000..904343c98
--- /dev/null
+++ b/miles/utils/test_utils/uvicorn_thread_server.py
@@ -0,0 +1,49 @@
+import asyncio
+import socket
+import threading
+import time
+
+import uvicorn
+
+
+class UvicornThreadServer:
+ def __init__(self, app, host: str, port: int):
+ self._app = app
+ self.host = host
+ self.port = port
+ self._server: uvicorn.Server | None = None
+ self._thread: threading.Thread | None = None
+
+ @property
+ def url(self) -> str:
+ return f"http://{self.host}:{self.port}"
+
+ def start(self) -> None:
+ config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info")
+ self._server = uvicorn.Server(config)
+
+ def run() -> None:
+ asyncio.run(self._server.serve())
+
+ self._thread = threading.Thread(target=run, daemon=True)
+ self._thread.start()
+ self._wait_for_port_open()
+
+ def stop(self) -> None:
+ if self._server is not None:
+ self._server.should_exit = True
+ if self._thread is not None and self._thread.is_alive():
+ self._thread.join(timeout=2.0)
+
+ def _wait_for_port_open(self) -> None:
+ for _ in range(50):
+ try:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ result = sock.connect_ex((self.host, self.port))
+ sock.close()
+ if result == 0:
+ return
+ except Exception:
+ pass
+ time.sleep(0.1)
+ raise RuntimeError(f"Failed to start server on {self.url}")
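+
+
+# Illustrative usage:
+#   server = UvicornThreadServer(app, host="127.0.0.1", port=8000)
+#   server.start()      # blocks until the port accepts connections
+#   ...                 # issue HTTP requests against server.url
+#   server.stop()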
diff --git a/miles/utils/types.py b/miles/utils/types.py
index 0a2531a7a..5200d625e 100644
--- a/miles/utils/types.py
+++ b/miles/utils/types.py
@@ -145,6 +145,24 @@ def get_reward_value(self, args) -> float:
def effective_response_length(self):
return sum(self.loss_mask) if self.loss_mask is not None else self.response_length
+ def validate(self):
+ assert self.response_length >= 0, f"response_length must be >= 0, got {self.response_length}"
+ assert (
+ len(self.tokens) >= self.response_length
+ ), f"tokens length ({len(self.tokens)}) must be >= response_length ({self.response_length})"
+ if self.loss_mask is not None:
+ assert (
+ len(self.loss_mask) == self.response_length
+ ), f"loss_mask length ({len(self.loss_mask)}) != response_length ({self.response_length})"
+ if self.rollout_log_probs is not None:
+ assert (
+ len(self.rollout_log_probs) == self.response_length
+ ), f"rollout_log_probs length ({len(self.rollout_log_probs)}) != response_length ({self.response_length})"
+ if self.rollout_routed_experts is not None:
+ actual = len(self.rollout_routed_experts)
+ expect = len(self.tokens) - 1
+ assert actual == expect, f"rollout_routed_experts length ({actual}) != len(tokens) - 1 ({expect})"
+
def update_from_meta_info(self, args, meta_info: dict):
"""
Update the sample with new information from meta_info returned by the rollout engine.
diff --git a/requirements.txt b/requirements.txt
index 2c20195fc..dacd51132 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ mcp[cli]
memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml
omegaconf
pillow
+pybase64
pylatexenc
pyyaml
qwen_vl_utils # for VLM
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000..8b1378917
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/ci/gpu_lock_exec.py b/tests/ci/gpu_lock_exec.py
index 9507e2e85..20379f76a 100644
--- a/tests/ci/gpu_lock_exec.py
+++ b/tests/ci/gpu_lock_exec.py
@@ -19,11 +19,14 @@ def main():
_execute_print_only(args)
return
- fd_locks = _try_acquire(args)
+ if args.count == 0 and not args.devices:
+ print("[gpu_lock_exec] Do not acquire GPU since count=0", flush=True)
+ else:
+ fd_locks = _try_acquire(args)
- dev_list = ",".join(str(x.gpu_id) for x in fd_locks)
- os.environ[args.target_env_name] = dev_list
- print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True)
+ dev_list = ",".join(str(x.gpu_id) for x in fd_locks)
+ os.environ[args.target_env_name] = dev_list
+ print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True)
_os_execvp(args)
diff --git a/tests/e2e/.gitkeep b/tests/e2e/.gitkeep
new file mode 100644
index 000000000..615f2b076
--- /dev/null
+++ b/tests/e2e/.gitkeep
@@ -0,0 +1 @@
+# TODO: may move e2e tests to this folder
\ No newline at end of file
diff --git a/tests/fast/__init__.py b/tests/fast/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/conftest.py b/tests/fast/conftest.py
new file mode 100644
index 000000000..4cb30e91f
--- /dev/null
+++ b/tests/fast/conftest.py
@@ -0,0 +1,15 @@
+import os
+
+import pytest
+
+from tests.fast.fixtures.generation_fixtures import generation_env
+from tests.fast.fixtures.rollout_fixtures import rollout_env
+
+_ = rollout_env, generation_env
+
+
+@pytest.fixture(autouse=True)
+def enable_experimental_rollout_refactor():
+ os.environ["MILES_EXPERIMENTAL_ROLLOUT_REFACTOR"] = "1"
+ yield
+ os.environ.pop("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", None)
diff --git a/tests/fast/fixtures/__init__.py b/tests/fast/fixtures/__init__.py
new file mode 100644
index 000000000..8b1378917
--- /dev/null
+++ b/tests/fast/fixtures/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py
new file mode 100644
index 000000000..816371ee3
--- /dev/null
+++ b/tests/fast/fixtures/generation_fixtures.py
@@ -0,0 +1,274 @@
+"""
+Fixtures to test custom-generate-function
+"""
+
+from argparse import Namespace
+from contextlib import contextmanager
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import Any
+from unittest.mock import patch
+
+import pytest
+import requests
+
+from miles.rollout.base_types import GenerateFnInput
+from miles.rollout.inference_rollout.compatibility import load_generate_function
+from miles.rollout.inference_rollout.inference_rollout_common import GenerateState
+from miles.router.router import MilesRouter
+from miles.utils.async_utils import run
+from miles.utils.http_utils import find_available_port, init_http_client
+from miles.utils.misc import SingletonMeta
+from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server
+from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer
+from miles.utils.types import Sample
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+RESPONSE_TEXT = "\\boxed{8}"
+DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7}
+
+VARIANT_TO_GENERATE_FN_PATH = {
+ "old_sglang_rollout": "miles.rollout.sglang_rollout.generate",
+ "single_turn": "miles.rollout.generate_hub.single_turn.generate",
+ "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn.generate",
+ "multi_turn_multi_samples": "miles.rollout.generate_hub.multi_turn.generate",
+ "agentic_tool_call_single_sample": "miles.rollout.generate_hub.agentic_tool_call.generate",
+ "agentic_tool_call_multi_samples": "miles.rollout.generate_hub.agentic_tool_call.generate",
+}
+
+
+def extra_argv_for_variant(
+ variant: str,
+ *,
+ custom_generate_function_path: str | None = None,
+ generate_max_turns: int = 16,
+ generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS",
+ generate_tool_call_parser: str = "qwen25",
+ generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call",
+) -> list[str]:
+ argv = [
+ "--custom-generate-function-path",
+ custom_generate_function_path or VARIANT_TO_GENERATE_FN_PATH[variant],
+ ]
+
+ if variant in (
+ "multi_turn_single_sample",
+ "multi_turn_multi_samples",
+ "agentic_tool_call_single_sample",
+ "agentic_tool_call_multi_samples",
+ ):
+ argv += [
+ "--generate-max-turns",
+ str(generate_max_turns),
+ "--generate-tool-specs-path",
+ generate_tool_specs_path,
+ "--generate-execute-tool-function-path",
+ generate_execute_tool_function_path,
+ ]
+ if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"):
+ argv += ["--generate-tool-call-parser", generate_tool_call_parser]
+ if variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples"):
+ argv.append("--generate-multi-samples")
+
+ return argv
+
+
+def listify(x):
+ return x if isinstance(x, list) else [x]
+
+
+def make_sample(
+ *,
+ prompt: str | list[dict] = "What is 1+7?",
+ tokens: list[int] | None = None,
+ response: str = "",
+ response_length: int = 0,
+ status: Sample.Status = Sample.Status.PENDING,
+ multimodal_inputs: dict | None = None,
+) -> Sample:
+ return Sample(
+ prompt=prompt,
+ tokens=tokens or [],
+ response=response,
+ response_length=response_length,
+ status=status,
+ multimodal_inputs=multimodal_inputs,
+ )
+
+
+@dataclass
+class GenerateEnv:
+ args: Namespace
+ mock_server: Any
+
+
+@dataclass
+class GenerateResult:
+ sample: Sample | list[Sample]
+ requests: list[dict]
+
+
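+# Clears the mock server's request log, runs the generate function to completion, and returns the produced sample(s) along with the requests the server saw.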
+def run_generate(
+ env: GenerateEnv,
+ sample: Sample,
+ sampling_params: dict[str, Any] | None = None,
+ *,
+ variant: str = "single_turn",
+) -> GenerateResult:
+ env.mock_server.request_log.clear()
+ result_sample = run(
+ _call_generate(
+ env.args,
+ sample,
+ sampling_params or DEFAULT_SAMPLING_PARAMS,
+ variant=variant,
+ )
+ )
+ return GenerateResult(sample=result_sample, requests=list(env.mock_server.request_log))
+
+
+async def _call_generate(
+ args: Namespace,
+ sample: Sample,
+ sampling_params: dict[str, Any],
+ *,
+ variant: str = "single_turn",
+) -> Sample:
+ generate_fn = load_generate_function(VARIANT_TO_GENERATE_FN_PATH[variant])
+ state = GenerateState(args)
+ fn_input = GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params.copy(), evaluation=False)
+ output = await generate_fn(fn_input)
+ return output.samples
+
+
+def make_args(
+ *,
+ variant: str,
+ router_port: int,
+ use_rollout_routing_replay: bool = False,
+ sglang_speculative_algorithm: str | None = None,
+ model_name: str = MODEL_NAME,
+ extra_argv: list[str] | None = None,
+ custom_generate_function_path: str | None = None,
+ generate_max_turns: int = 16,
+ generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS",
+ generate_tool_call_parser: str = "qwen25",
+ generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call",
+ rollout_max_context_len: int | None = None,
+) -> Namespace:
+ argv = [
+ "pytest",
+ "--train-backend",
+ "fsdp",
+ "--rollout-batch-size",
+ "1",
+ "--num-rollout",
+ "1",
+ "--rollout-num-gpus",
+ "1",
+ "--rollout-num-gpus-per-engine",
+ "1",
+ "--hf-checkpoint",
+ model_name,
+ "--prompt-data",
+ "/dev/null",
+ "--rm-type",
+ "math",
+ "--sglang-router-ip",
+ "127.0.0.1",
+ "--sglang-router-port",
+ str(router_port),
+ "--rollout-max-response-len",
+ "16",
+ ]
+ if use_rollout_routing_replay:
+ argv.append("--use-rollout-routing-replay")
+ if sglang_speculative_algorithm:
+ argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm])
+ if rollout_max_context_len is not None:
+ argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)])
+
+ argv.extend(
+ extra_argv_for_variant(
+ variant,
+ custom_generate_function_path=custom_generate_function_path,
+ generate_max_turns=generate_max_turns,
+ generate_tool_specs_path=generate_tool_specs_path,
+ generate_tool_call_parser=generate_tool_call_parser,
+ generate_execute_tool_function_path=generate_execute_tool_function_path,
+ )
+ )
+
+ if extra_argv:
+ argv.extend(extra_argv)
+
+ from miles.utils.arguments import parse_args
+
+ with patch("sys.argv", argv):
+ args = parse_args()
+
+ init_http_client(args)
+ return args
+
+
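+# Runs a MilesRouter in a background uvicorn thread, registers the mock backend as its worker, and yields the router port.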
+@contextmanager
+def with_miles_router(backend_url: str, model_name: str):
+ router_args = SimpleNamespace(
+ miles_router_max_connections=10,
+ miles_router_timeout=30,
+ miles_router_middleware_paths=[],
+ rollout_health_check_interval=60,
+ miles_router_health_check_failure_threshold=3,
+ hf_checkpoint=model_name,
+ )
+ router = MilesRouter(router_args)
+
+ port = find_available_port(31000)
+ server = UvicornThreadServer(router.app, host="127.0.0.1", port=port)
+ server.start()
+
+ url = f"http://127.0.0.1:{port}"
+ requests.post(f"{url}/add_worker", json={"url": backend_url})
+
+ try:
+ yield port
+ finally:
+ server.stop()
+
+
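+# Wires a mock SGLang server behind a MilesRouter and builds real CLI args pointing at it, so generate functions run end-to-end against a fake backend.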
+@pytest.fixture
+def generation_env(request, variant):
+ SingletonMeta.clear_all_instances()
+ params = getattr(request, "param", {})
+ args_kwargs = params.get("args_kwargs", {})
+ model_name = args_kwargs.get("model_name", MODEL_NAME)
+ custom_generate_function_path = VARIANT_TO_GENERATE_FN_PATH[variant]
+
+ def process_fn(_):
+ x = params.get("process_fn_kwargs", {})
+ return ProcessResult(
+ text=x.get("response_text", RESPONSE_TEXT),
+ finish_reason=x.get("finish_reason", "stop"),
+ cached_tokens=x.get("cached_tokens", 0),
+ meta_info=ProcessResultMetaInfo(
+ weight_version=x.get("weight_version"),
+ routed_experts=x.get("routed_experts"),
+ spec_accept_token_num=x.get("spec_accept_token_num"),
+ spec_draft_token_num=x.get("spec_draft_token_num"),
+ spec_verify_ct=x.get("spec_verify_ct"),
+ ),
+ )
+
+ with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server:
+ with with_miles_router(mock_server.url, model_name) as router_port:
+ other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"}
+ args = make_args(
+ variant=variant,
+ router_port=router_port,
+ model_name=model_name,
+ custom_generate_function_path=custom_generate_function_path,
+ **other_args_kwargs,
+ )
+ yield GenerateEnv(args=args, mock_server=mock_server)
+
+ SingletonMeta.clear_all_instances()
diff --git a/tests/fast/fixtures/rollout_fixtures.py b/tests/fast/fixtures/rollout_fixtures.py
new file mode 100644
index 000000000..44d8a50d7
--- /dev/null
+++ b/tests/fast/fixtures/rollout_fixtures.py
@@ -0,0 +1,127 @@
+"""
+Fixtures for testing rollout functions.
+"""
+
+import json
+from argparse import Namespace
+from collections.abc import Iterator
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+import requests
+
+from miles.rollout.data_source import DataSource, RolloutDataSourceWithBuffer
+from miles.router.router import MilesRouter
+from miles.utils.arguments import parse_args
+from miles.utils.http_utils import find_available_port, init_http_client
+from miles.utils.misc import SingletonMeta
+from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, with_mock_server
+from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer
+
+
+@dataclass(frozen=True)
+class RolloutEnvConfig:
+ extra_argv: list[str] | None = None
+ data_rows: list[dict] | None = None
+ latency: float = 0.0
+
+
+@dataclass(frozen=True)
+class RolloutEnv:
+ args: Namespace
+ data_source: DataSource
+ mock_server: MockSGLangServer
+
+
+def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | None = None) -> Namespace:
+ argv = [
+ "pytest",
+ "--train-backend",
+ "fsdp",
+ "--rollout-batch-size",
+ "1",
+ "--n-samples-per-prompt",
+ "1",
+ "--num-rollout",
+ "1",
+ "--rollout-num-gpus",
+ "1",
+ "--rollout-num-gpus-per-engine",
+ "1",
+ "--hf-checkpoint",
+ "Qwen/Qwen3-0.6B",
+ "--prompt-data",
+ data_path,
+ "--input-key",
+ "input",
+ "--label-key",
+ "label",
+ "--rm-type",
+ "math",
+ "--eval-prompt-data",
+ "toy",
+ data_path,
+ "--use-miles-router",
+ "--sglang-router-ip",
+ "127.0.0.1",
+ "--sglang-router-port",
+ str(router_port),
+ "--rollout-max-response-len",
+ "16",
+ ] + (extra_argv or [])
+ with patch("sys.argv", argv):
+ args = parse_args()
+ args.miles_router_middleware_paths = []
+ init_http_client(args)
+ return args
+
+
+@contextmanager
+def _with_miles_router(args: Namespace) -> Iterator[UvicornThreadServer]:
+ router = MilesRouter(args, verbose=False)
+ server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port)
+ try:
+ server.start()
+ yield server
+ finally:
+ server.stop()
+
+
+def _write_jsonl(path: str, rows: list[dict]) -> None:
+ Path(path).write_text("".join(json.dumps(row, ensure_ascii=False) + "\n" for row in rows), encoding="utf-8")
+
+
+DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}]
+
+
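+# Full rollout environment: temporary JSONL data, parsed CLI args, a mock SGLang server registered with a real MilesRouter, and a buffered data source.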
+@pytest.fixture
+def rollout_env(tmp_path, request) -> RolloutEnv:
+ config = request.param
+ assert isinstance(config, RolloutEnvConfig)
+
+ data_rows = config.data_rows or DEFAULT_DATA_ROWS
+
+ data_path = str(tmp_path / "data.jsonl")
+ _write_jsonl(data_path, data_rows)
+
+ router_port = find_available_port(20000)
+ args = _build_args(data_path=data_path, router_port=router_port, extra_argv=config.extra_argv)
+
+ SingletonMeta.clear_all_instances()
+
+ with with_mock_server(model_name=args.hf_checkpoint, latency=config.latency) as mock_server:
+ with _with_miles_router(args) as router_server:
+ r = requests.post(
+ f"{router_server.url}/add_worker",
+ params={"url": mock_server.url},
+ timeout=5.0,
+ )
+ r.raise_for_status()
+
+ data_source = RolloutDataSourceWithBuffer(args)
+ yield RolloutEnv(args=args, data_source=data_source, mock_server=mock_server)
+
+ SingletonMeta.clear_all_instances()
diff --git a/tests/fast/rollout/__init__.py b/tests/fast/rollout/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/rollout/generate_hub/__init__.py b/tests/fast/rollout/generate_hub/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/rollout/generate_hub/test_multi_turn.py b/tests/fast/rollout/generate_hub/test_multi_turn.py
new file mode 100644
index 000000000..c3ef3e855
--- /dev/null
+++ b/tests/fast/rollout/generate_hub/test_multi_turn.py
@@ -0,0 +1,656 @@
+from copy import deepcopy
+from dataclasses import dataclass, replace
+from itertools import groupby
+
+import numpy as np
+import pybase64
+import pytest
+from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate
+from transformers import AutoTokenizer
+
+from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo
+from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, ThinkingThreeTurnStub, ThreeTurnStub, TwoTurnStub
+from miles.utils.types import Sample
+
+_ = generation_env, SAMPLE_TOOLS, TwoTurnStub, ThreeTurnStub
+
+
+def is_agentic_variant(variant: str) -> bool:
+ return variant in ("agentic_tool_call_single_sample", "agentic_tool_call_multi_samples")
+
+
+# ------------------------------------ fixtures and consts ----------------------------------------
+
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7}
+TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+
+
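+# Every test in this module runs against all four multi-turn / agentic variants.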
+@pytest.fixture(
+ params=[
+ "multi_turn_single_sample",
+ "multi_turn_multi_samples",
+ "agentic_tool_call_single_sample",
+ "agentic_tool_call_multi_samples",
+ ]
+)
+def variant(request):
+ return request.param
+
+
+@dataclass(frozen=True)
+class SampleParsedChunk:
+ tokens_decoded_str: str
+ loss_mask_value: int
+ rollout_log_probs: list[float]
+
+
+@dataclass
+class ExpectedSampleInfo:
+ chunks: list[SampleParsedChunk]
+ partial_sample: Sample
+
+
+def token_len(text: str) -> int:
+ return len(TOKENIZER(text, add_special_tokens=False)["input_ids"])
+
+
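+# Mirrors the mock server's log-prob pattern: the i-th generated token gets log prob -i/128, while tool-injected chunks (loss_mask 0) are expected to carry 0.0.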
+def expected_chunk(text: str, loss_mask: int) -> SampleParsedChunk:
+ n = token_len(text)
+ log_probs = [-1 / 128 * i for i in range(n)] if loss_mask else [0.0] * n
+ return SampleParsedChunk(text, loss_mask, log_probs)
+
+
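+# Splits the response tokens into contiguous runs that share a loss-mask value, so generated text and tool output can be checked chunk by chunk.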
+def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]:
+ prompt_len = len(sample.tokens) - sample.response_length
+ response_tokens = sample.tokens[prompt_len:]
+ loss_mask = sample.loss_mask or []
+ log_probs = sample.rollout_log_probs or []
+
+ chunks = []
+ idx = 0
+ for mask_val, group in groupby(loss_mask):
+ group_len = len(list(group))
+ sli = slice(idx, idx + group_len)
+ chunks.append(
+ SampleParsedChunk(
+ tokens_decoded_str=tokenizer.decode(response_tokens[sli]),
+ loss_mask_value=mask_val,
+ rollout_log_probs=log_probs[sli],
+ )
+ )
+ idx += group_len
+ return chunks
+
+
+def expected_partial_sample(
+ *,
+ prompt: list[dict],
+ response: str,
+ response_length: int,
+ status: Sample.Status = Sample.Status.COMPLETED,
+) -> Sample:
+ return Sample(
+ prompt=prompt,
+ response=response,
+ response_length=response_length,
+ status=status,
+ tokens=[],
+ loss_mask=[],
+ rollout_log_probs=[],
+ weight_versions=[],
+ spec_info=Sample.SpecInfo(),
+ prefix_cache_info=Sample.PrefixCacheInfo(),
+ )
+
+
+def verify_samples(actual: Sample | list[Sample], expected: list[ExpectedSampleInfo]):
+ actual = listify(actual)
+ assert len(actual) == len(expected)
+
+ for actual_item, expected_item in zip(actual, expected, strict=True):
+ actual_chunks = parse_sample_into_chunks(actual_item, TOKENIZER)
+ assert actual_chunks == expected_item.chunks
+
+ actual_partial = replace(
+ deepcopy(actual_item),
+ tokens=[],
+ loss_mask=[],
+ rollout_log_probs=[],
+ prefix_cache_info=Sample.PrefixCacheInfo(),
+ )
+ assert actual_partial == expected_item.partial_sample
+
+
+def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_params: dict | None = None):
+ return run_generate(env, sample, sampling_params, variant=variant)
+
+
+def expected_request(input_ids: list[int], sampling_params: dict | None = None) -> dict:
+ return {
+ "input_ids": input_ids,
+ "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS,
+ "return_logprob": True,
+ "return_routed_experts": False,
+ }
+
+
+def expected_openai_request(messages: list[dict]) -> dict:
+ input_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
+ return {"messages": messages, "model": "default", "tools": SAMPLE_TOOLS, "input_ids": input_ids}
+
+
+SINGLE_TURN_PROMPT = [{"role": "user", "content": "What is 1+1?"}]
+SINGLE_TURN_RESPONSE = "The answer is 2."
+_SINGLE_TURN_PROMPT_TEXT = TOKENIZER.apply_chat_template(
+ SINGLE_TURN_PROMPT, tokenize=False, add_generation_prompt=True, tools=SAMPLE_TOOLS
+)
+SINGLE_TURN_PROMPT_TOKEN_IDS = TOKENIZER(_SINGLE_TURN_PROMPT_TEXT, add_special_tokens=False)["input_ids"]
+SINGLE_TURN_PROMPT_TOKEN_LEN = len(SINGLE_TURN_PROMPT_TOKEN_IDS)
+
+
+# ------------------------------------ tests ----------------------------------------
+
+
+class TestBasicMultiTurn:
+ def test_single_turn_no_tool_call(self, variant, generation_env):
+ generation_env.mock_server.process_fn = lambda _: ProcessResult(
+ text=SINGLE_TURN_RESPONSE, finish_reason="stop"
+ )
+
+ result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT))
+
+ if is_agentic_variant(variant):
+ assert result.requests == [expected_openai_request(SINGLE_TURN_PROMPT)]
+ else:
+ assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)]
+ verify_samples(
+ result.sample,
+ [
+ ExpectedSampleInfo(
+ chunks=[
+ SampleParsedChunk(
+ tokens_decoded_str=SINGLE_TURN_RESPONSE,
+ loss_mask_value=1,
+ rollout_log_probs=[-1 / 128 * i for i in range(6)],
+ ),
+ ],
+ partial_sample=expected_partial_sample(
+ prompt=SINGLE_TURN_PROMPT, response=SINGLE_TURN_RESPONSE, response_length=6
+ ),
+ ),
+ ],
+ )
+
+ def test_two_turns_with_tool_call(self, variant, generation_env):
+ generation_env.mock_server.process_fn = TwoTurnStub.process_fn
+
+ S = TwoTurnStub
+ result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT))
+
+ if is_agentic_variant(variant):
+ assert result.requests == [
+ expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN),
+ expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT),
+ ]
+ else:
+ assert result.requests == [
+ expected_request(S.FIRST_PROMPT_TOKEN_IDS),
+ expected_request(S.SECOND_PROMPT_TOKEN_IDS),
+ ]
+ if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"):
+ full_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + S.SECOND_RESPONSE
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[
+ expected_chunk(S.FIRST_RESPONSE, 1),
+ expected_chunk(S.FIRST_TOOL_RESPONSE, 0),
+ expected_chunk(S.SECOND_RESPONSE, 1),
+ ],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=full_response,
+ response_length=token_len(full_response),
+ ),
+ ),
+ ]
+ else:
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.FIRST_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.FIRST_RESPONSE,
+ response_length=token_len(S.FIRST_RESPONSE),
+ ),
+ ),
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.SECOND_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.SECOND_RESPONSE,
+ response_length=token_len(S.SECOND_RESPONSE),
+ ),
+ ),
+ ]
+ verify_samples(result.sample, expected)
+
+
+class TestExitConditions:
+ def test_partial_rollout_not_supported(self, variant, generation_env):
+ if is_agentic_variant(variant):
+ pytest.skip("agentic_tool_call does not check partial_rollout flag")
+ generation_env.args.partial_rollout = True
+
+ with pytest.raises(AssertionError, match="Partial rollout is not supported"):
+ _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT))
+
+ def test_abort_preserves_content(self, variant, generation_env):
+ if is_agentic_variant(variant):
+ pytest.skip("agentic_tool_call does not handle abort finish_reason")
+ generation_env.mock_server.process_fn = lambda _: ProcessResult(
+ text=SINGLE_TURN_RESPONSE, finish_reason="abort"
+ )
+
+ result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT))
+
+ assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)]
+ verify_samples(
+ result.sample,
+ [
+ ExpectedSampleInfo(
+ chunks=[
+ SampleParsedChunk(
+ tokens_decoded_str=SINGLE_TURN_RESPONSE,
+ loss_mask_value=1,
+ rollout_log_probs=[-1 / 128 * i for i in range(6)],
+ ),
+ ],
+ partial_sample=expected_partial_sample(
+ prompt=SINGLE_TURN_PROMPT,
+ response=SINGLE_TURN_RESPONSE,
+ response_length=6,
+ status=Sample.Status.ABORTED,
+ ),
+ ),
+ ],
+ )
+
+ def test_finish_reason_length_exits_and_preserves_content(self, variant, generation_env):
+ S = TwoTurnStub
+ generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="length")
+
+ result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT))
+
+ if is_agentic_variant(variant):
+ assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)]
+ else:
+ assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)]
+ verify_samples(
+ result.sample,
+ [
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.FIRST_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.FIRST_RESPONSE,
+ response_length=token_len(S.FIRST_RESPONSE),
+ status=Sample.Status.TRUNCATED,
+ ),
+ ),
+ ],
+ )
+
+ @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"generate_max_turns": 1}}], indirect=True)
+ def test_max_turns_reached(self, variant, generation_env):
+ S = TwoTurnStub
+ generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="stop")
+
+ result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT))
+
+ if is_agentic_variant(variant):
+ assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)]
+ else:
+ assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)]
+ if variant == "multi_turn_single_sample":
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[
+ expected_chunk(S.FIRST_RESPONSE, 1),
+ expected_chunk(S.FIRST_TOOL_RESPONSE, 0),
+ ],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE,
+ response_length=token_len(S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE),
+ ),
+ ),
+ ]
+ else:
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.FIRST_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.FIRST_RESPONSE,
+ response_length=token_len(S.FIRST_RESPONSE),
+ ),
+ ),
+ ]
+ verify_samples(result.sample, expected)
+
+
+class TestRespectMaxContextLen:
+ @pytest.mark.parametrize(
+ "generation_env", [{"args_kwargs": {"rollout_max_context_len": SINGLE_TURN_PROMPT_TOKEN_LEN}}], indirect=True
+ )
+ def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env):
+ if is_agentic_variant(variant):
+ pytest.skip("TODO: implement")
+ result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT))
+ assert result.requests == []
+ if variant == "multi_turn_single_sample":
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[],
+ partial_sample=expected_partial_sample(
+ prompt=SINGLE_TURN_PROMPT, response="", response_length=0, status=Sample.Status.TRUNCATED
+ ),
+ )
+ ]
+ else:
+ expected = []
+ verify_samples(result.sample, expected)
+
+ @pytest.mark.parametrize(
+ "generation_env",
+ [
+ {
+ "args_kwargs": {
+ "rollout_max_context_len": len(TwoTurnStub.FIRST_PROMPT_TOKEN_IDS)
+ + token_len(TwoTurnStub.FIRST_RESPONSE)
+ + token_len(TwoTurnStub.FIRST_TOOL_RESPONSE)
+ }
+ }
+ ],
+ indirect=True,
+ )
+ def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env):
+ if is_agentic_variant(variant):
+ pytest.skip("TODO: implement")
+ S = TwoTurnStub
+ generation_env.mock_server.process_fn = S.process_fn
+
+ result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT))
+
+ assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)]
+ if variant == "multi_turn_single_sample":
+ partial_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[
+ expected_chunk(S.FIRST_RESPONSE, 1),
+ expected_chunk(S.FIRST_TOOL_RESPONSE, 0),
+ ],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=partial_response,
+ response_length=token_len(partial_response),
+ status=Sample.Status.TRUNCATED,
+ ),
+ ),
+ ]
+ else:
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.FIRST_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.FIRST_RESPONSE,
+ response_length=token_len(S.FIRST_RESPONSE),
+ status=Sample.Status.TRUNCATED,
+ ),
+ ),
+ ]
+ verify_samples(result.sample, expected)
+
+ @pytest.mark.parametrize(
+ "generation_env,expected_max_new_tokens",
+ [
+ (
+ {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 10}},
+ 10,
+ ),
+ (
+ {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 100}},
+ 64,
+ ),
+ ],
+ indirect=["generation_env"],
+ )
+ def test_second_turn_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens):
+ if is_agentic_variant(variant):
+ pytest.skip("TODO: implement")
+ S = TwoTurnStub
+ generation_env.mock_server.process_fn = S.process_fn
+
+ result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT))
+
+ assert len(result.requests) >= 2
+ assert result.requests[1]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens
+ assert result.requests[1]["sampling_params"]["temperature"] == DEFAULT_SAMPLING_PARAMS["temperature"]
+
+
+class TestThreeTurn:
+ """Need to test 3-turn case besides 2-turn, because e.g. merge_samples may behave differently."""
+
+ def test_three_turns_with_sequential_tool_calls(self, variant, generation_env):
+ generation_env.mock_server.process_fn = ThreeTurnStub.process_fn
+
+ S = ThreeTurnStub
+ result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT))
+
+ if is_agentic_variant(variant):
+ assert result.requests == [
+ expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN),
+ expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT),
+ expected_openai_request(S.OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT),
+ ]
+ else:
+ assert result.requests == [
+ expected_request(S.FIRST_PROMPT_TOKEN_IDS),
+ expected_request(S.SECOND_PROMPT_TOKEN_IDS),
+ expected_request(S.THIRD_PROMPT_TOKEN_IDS),
+ ]
+ if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"):
+ full_response = (
+ S.FIRST_RESPONSE
+ + S.FIRST_TOOL_RESPONSE
+ + S.SECOND_RESPONSE
+ + S.SECOND_TOOL_RESPONSE
+ + S.THIRD_RESPONSE
+ )
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[
+ expected_chunk(S.FIRST_RESPONSE, 1),
+ expected_chunk(S.FIRST_TOOL_RESPONSE, 0),
+ expected_chunk(S.SECOND_RESPONSE, 1),
+ expected_chunk(S.SECOND_TOOL_RESPONSE, 0),
+ expected_chunk(S.THIRD_RESPONSE, 1),
+ ],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=full_response,
+ response_length=token_len(full_response),
+ ),
+ ),
+ ]
+ else:
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.FIRST_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.FIRST_RESPONSE,
+ response_length=token_len(S.FIRST_RESPONSE),
+ ),
+ ),
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.SECOND_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.SECOND_RESPONSE,
+ response_length=token_len(S.SECOND_RESPONSE),
+ ),
+ ),
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.THIRD_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.THIRD_RESPONSE,
+ response_length=token_len(S.THIRD_RESPONSE),
+ ),
+ ),
+ ]
+ verify_samples(result.sample, expected)
+
+
+class TestFourTurnWithThink:
+ def test_four_turns_with_think_prefix(self, variant, generation_env):
+ generation_env.mock_server.process_fn = ThinkingThreeTurnStub.process_fn
+
+ S = ThinkingThreeTurnStub
+ result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT))
+
+ if is_agentic_variant(variant):
+ assert result.requests == [
+ expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN),
+ expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT),
+ expected_openai_request(S.OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT),
+ ]
+ else:
+ assert result.requests == [
+ expected_request(S.FIRST_PROMPT_TOKEN_IDS),
+ expected_request(S.SECOND_PROMPT_TOKEN_IDS),
+ expected_request(S.THIRD_PROMPT_TOKEN_IDS),
+ ]
+ if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"):
+ full_response = (
+ S.FIRST_RESPONSE
+ + S.FIRST_TOOL_RESPONSE
+ + S.SECOND_RESPONSE
+ + S.SECOND_TOOL_RESPONSE
+ + S.THIRD_RESPONSE
+ )
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[
+ expected_chunk(S.FIRST_RESPONSE, 1),
+ expected_chunk(S.FIRST_TOOL_RESPONSE, 0),
+ expected_chunk(S.SECOND_RESPONSE, 1),
+ expected_chunk(S.SECOND_TOOL_RESPONSE, 0),
+ expected_chunk(S.THIRD_RESPONSE, 1),
+ ],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=full_response,
+ response_length=token_len(full_response),
+ ),
+ ),
+ ]
+ else:
+ expected = [
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.FIRST_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.FIRST_RESPONSE,
+ response_length=token_len(S.FIRST_RESPONSE),
+ ),
+ ),
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.SECOND_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.SECOND_RESPONSE,
+ response_length=token_len(S.SECOND_RESPONSE),
+ ),
+ ),
+ ExpectedSampleInfo(
+ chunks=[expected_chunk(S.THIRD_RESPONSE, 1)],
+ partial_sample=expected_partial_sample(
+ prompt=S.PROMPT,
+ response=S.THIRD_RESPONSE,
+ response_length=token_len(S.THIRD_RESPONSE),
+ ),
+ ),
+ ]
+ verify_samples(result.sample, expected)
+
+ messages_without_think = deepcopy(S.OPENAI_MESSAGES_FOURTH_TURN_FROM_CLIENT)
+ messages_without_think[1]["content"] = messages_without_think[1]["content"].replace(S.THINK_PREFIX, "")
+ token_ids_with_think = TOKENIZER.apply_chat_template(
+ S.OPENAI_MESSAGES_FOURTH_TURN_FROM_CLIENT, tokenize=True, add_generation_prompt=True
+ )
+ token_ids_without_think = TOKENIZER.apply_chat_template(
+ messages_without_think, tokenize=True, add_generation_prompt=True
+ )
+ assert token_ids_with_think == token_ids_without_think
+
+
+class TestRoutedExpertsMultiTurn:
+ @pytest.mark.parametrize(
+ "generation_env",
+ [
+ {
+ "args_kwargs": {
+ "use_rollout_routing_replay": True,
+ }
+ }
+ ],
+ indirect=True,
+ )
+ def test_two_turns_routed_experts(self, variant, generation_env):
+ if is_agentic_variant(variant):
+ pytest.skip("TODO: implement")
+
+ S = TwoTurnStub
+ num_layers, moe_router_topk = 2, 4
+ generation_env.args.num_layers = num_layers
+ generation_env.args.moe_router_topk = moe_router_topk
+
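+ # The routed-experts payload covers every token except the last: shape (num_tokens - 1, num_layers, moe_router_topk).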
+ def make_routed_experts(prompt_token_ids, response_text):
+ total_tokens = len(prompt_token_ids) + token_len(response_text)
+ routed_experts_len = total_tokens - 1
+ return np.arange(routed_experts_len * num_layers * moe_router_topk, dtype=np.int32).reshape(
+ routed_experts_len, num_layers, moe_router_topk
+ )
+
+ first_routed_experts = make_routed_experts(S.FIRST_PROMPT_TOKEN_IDS, S.FIRST_RESPONSE)
+ second_routed_experts = make_routed_experts(S.SECOND_PROMPT_TOKEN_IDS, S.SECOND_RESPONSE)
+
+ def process_fn(prompt: str) -> ProcessResult:
+ if prompt == S.FIRST_PROMPT:
+ text, routed_experts = S.FIRST_RESPONSE, first_routed_experts
+ elif prompt == S.SECOND_PROMPT:
+ text, routed_experts = S.SECOND_RESPONSE, second_routed_experts
+ else:
+ raise ValueError(f"Unexpected prompt: {prompt}")
+ return ProcessResult(
+ text=text,
+ finish_reason="stop",
+ meta_info=ProcessResultMetaInfo(
+ routed_experts=pybase64.b64encode(routed_experts.tobytes()).decode("ascii")
+ ),
+ )
+
+ generation_env.mock_server.process_fn = process_fn
+ result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT), DEFAULT_SAMPLING_PARAMS)
+
+ sample = result.sample[-1] if isinstance(result.sample, list) else result.sample
+ assert sample.rollout_routed_experts is not None
+ assert sample.rollout_routed_experts.shape == second_routed_experts.shape
+ np.testing.assert_array_equal(sample.rollout_routed_experts, second_routed_experts)
+ assert len(sample.tokens) - 1 == second_routed_experts.shape[0]
diff --git a/tests/fast/rollout/generate_hub/test_single_turn.py b/tests/fast/rollout/generate_hub/test_single_turn.py
new file mode 100644
index 000000000..a58e6fb3c
--- /dev/null
+++ b/tests/fast/rollout/generate_hub/test_single_turn.py
@@ -0,0 +1,424 @@
+import numpy as np
+import pybase64
+import pytest
+import torch
+from PIL import Image
+from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate
+from transformers import AutoProcessor
+
+from miles.utils.processing_utils import encode_image_for_rollout_engine
+from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo
+from miles.utils.types import Sample
+
+_ = generation_env
+
+# ------------------------------------ fixtures and consts ----------------------------------------
+
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+PROMPT = "What is 1+7?"
+PROMPT_TOKENS = [3838, 374, 220, 16, 10, 22, 30]
+PROMPT_TOKEN_LEN = len(PROMPT_TOKENS)
+RESPONSE_TOKENS = [59, 79075, 90, 23, 92]
+RESPONSE_TEXT = "\\boxed{8}"
+RESPONSE_LOG_PROBS = [-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125]
+SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7}
+DEFAULT_MAX_NEW_TOKENS = SAMPLING_PARAMS["max_new_tokens"]
+
+
+@pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample", "multi_turn_multi_samples"])
+def variant(request):
+ return request.param
+
+
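+# The legacy sglang rollout only sends return_routed_experts when it is enabled; the refactored variants always include the field.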
+def expected_request(
+ variant: str,
+ *,
+ input_ids: list[int] | None = None,
+ sampling_params: dict | None = None,
+ return_routed_experts: bool = False,
+ image_data: list[str] | None = None,
+) -> dict:
+ result = {
+ "input_ids": input_ids or PROMPT_TOKENS,
+ "sampling_params": sampling_params or SAMPLING_PARAMS,
+ "return_logprob": True,
+ }
+ if variant in ("single_turn", "multi_turn_single_sample", "multi_turn_multi_samples") or return_routed_experts:
+ result["return_routed_experts"] = return_routed_experts
+ if image_data is not None:
+ result["image_data"] = image_data
+ return result
+
+
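+# Sentinel so expected_sample can tell "argument not provided" apart from an explicit None.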
+class _Unset:
+ pass
+
+
+_UNSET = _Unset()
+
+
+def expected_sample(
+ variant: str,
+ *,
+ prompt: str = PROMPT,
+ response: str = RESPONSE_TEXT,
+ response_length: int = 5,
+ tokens: list[int] | None | _Unset = _UNSET,
+ rollout_log_probs: list[float] | None | _Unset = _UNSET,
+ status: Sample.Status = Sample.Status.COMPLETED,
+ cached_tokens: int = 0,
+ prompt_tokens: int = 7,
+ weight_versions: list[str] | None = None,
+ rollout_routed_experts: np.ndarray | None = None,
+ spec_info: Sample.SpecInfo | None = None,
+ multimodal_inputs: dict | None = None,
+ multimodal_train_inputs: dict | None = None,
+ loss_mask: list[int] | None | _Unset = _UNSET,
+) -> Sample:
+ actual_response_length = response_length if response_length is not None else len(RESPONSE_TOKENS)
+ if isinstance(loss_mask, _Unset):
+ loss_mask = (
+ [1] * actual_response_length
+ if variant in ("multi_turn_single_sample", "multi_turn_multi_samples")
+ else None
+ )
+
+ return Sample(
+ group_index=None,
+ index=None,
+ prompt=prompt,
+ tokens=PROMPT_TOKENS + RESPONSE_TOKENS if isinstance(tokens, _Unset) else tokens,
+ multimodal_inputs=multimodal_inputs,
+ multimodal_train_inputs=multimodal_train_inputs,
+ response=response,
+ response_length=response_length,
+ label=None,
+ reward=None,
+ loss_mask=loss_mask,
+ weight_versions=weight_versions or [],
+ rollout_log_probs=RESPONSE_LOG_PROBS if isinstance(rollout_log_probs, _Unset) else rollout_log_probs,
+ rollout_routed_experts=rollout_routed_experts,
+ remove_sample=False,
+ status=status,
+ metadata={},
+ train_metadata=None,
+ non_generation_time=0.0,
+ spec_info=spec_info or Sample.SpecInfo(),
+ prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=cached_tokens, total_prompt_tokens=prompt_tokens),
+ )
+
+
+def _make_sample(tokens=None, response="", response_length=0, status=Sample.Status.PENDING, multimodal_inputs=None):
+ return make_sample(
+ prompt=PROMPT,
+ tokens=tokens,
+ response=response,
+ response_length=response_length,
+ status=status,
+ multimodal_inputs=multimodal_inputs,
+ )
+
+
+def _run_generate(variant: str, env: GenerateEnv, sample: Sample | None = None, sampling_params: dict | None = None):
+ return run_generate(env, sample or _make_sample(), sampling_params or SAMPLING_PARAMS, variant=variant)
+
+
+# ------------------------------------ tests ----------------------------------------
+
+
+class TestBasicGeneration:
+ def test_basic_generation(self, variant, generation_env):
+ result = _run_generate(variant, generation_env)
+ assert result.requests == [expected_request(variant)]
+ assert listify(result.sample) == [expected_sample(variant)]
+
+
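+# A second generate call on an aborted sample resumes from its accumulated tokens and shrinks max_new_tokens by what was already produced.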
+class TestResumedSingleTurn:
+ def test_two_consecutive_calls_on_same_sample(self, variant, generation_env):
+ if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"):
+ pytest.skip("not tested yet")
+ partial_text = "\\boxed"
+ partial_tokens = [59, 79075]
+ partial_log_probs = [-0.0, -0.0078125]
+
+ remaining_text = "{8}"
+ remaining_tokens = [90, 23, 92]
+ remaining_log_probs = [-0.0, -0.0078125, -0.015625]
+
+ generation_env.mock_server.process_fn = lambda _: ProcessResult(text=partial_text, finish_reason="abort")
+ sample = _make_sample()
+ result1 = _run_generate(variant, generation_env, sample)
+ assert result1.requests == [expected_request(variant)]
+ assert result1.sample == expected_sample(
+ variant,
+ response=partial_text,
+ response_length=2,
+ tokens=PROMPT_TOKENS + partial_tokens,
+ rollout_log_probs=partial_log_probs,
+ status=Sample.Status.ABORTED,
+ )
+
+ generation_env.mock_server.process_fn = lambda _: ProcessResult(text=remaining_text, finish_reason="stop")
+ result2 = _run_generate(variant, generation_env, result1.sample)
+ tokens_after_turn1 = PROMPT_TOKENS + partial_tokens
+ assert result2.requests == [
+ expected_request(
+ variant,
+ input_ids=tokens_after_turn1,
+ sampling_params={"max_new_tokens": 14, "temperature": 0.7},
+ )
+ ]
+ assert result2.sample == expected_sample(
+ variant,
+ response=partial_text + remaining_text,
+ response_length=2 + 3,
+ tokens=tokens_after_turn1 + remaining_tokens,
+ rollout_log_probs=partial_log_probs + remaining_log_probs,
+ prompt_tokens=len(PROMPT_TOKENS) + len(tokens_after_turn1),
+ status=Sample.Status.COMPLETED,
+ )
+
+
+class TestFinishReason:
+ @pytest.mark.parametrize(
+ "generation_env,expected_status",
+ [
+ ({"process_fn_kwargs": {"finish_reason": "stop"}}, Sample.Status.COMPLETED),
+ ({"process_fn_kwargs": {"finish_reason": "length"}}, Sample.Status.TRUNCATED),
+ ({"process_fn_kwargs": {"finish_reason": "abort"}}, Sample.Status.ABORTED),
+ ],
+ indirect=["generation_env"],
+ )
+ def test_finish_reason_sets_status(self, variant, generation_env, expected_status):
+ result = _run_generate(variant, generation_env)
+ assert result.requests == [expected_request(variant)]
+ assert listify(result.sample) == [expected_sample(variant, status=expected_status)]
+
+
+class TestRoutedExperts:
+ @pytest.mark.parametrize(
+ "generation_env",
+ [
+ {
+ "args_kwargs": {"use_rollout_routing_replay": True},
+ "process_fn_kwargs": {"routed_experts": "placeholder"},
+ }
+ ],
+ indirect=True,
+ )
+ def test_routed_experts_enabled_and_parsed(self, variant, generation_env):
+ num_layers, moe_router_topk = 2, 4
+ num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS)
+ routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape(
+ num_tokens - 1, num_layers, moe_router_topk
+ )
+
+ generation_env.args.num_layers = num_layers
+ generation_env.args.moe_router_topk = moe_router_topk
+ routed_experts_str = pybase64.b64encode(routed_experts_array.tobytes()).decode("ascii")
+ generation_env.mock_server.process_fn = lambda _: ProcessResult(
+ text=RESPONSE_TEXT,
+ finish_reason="stop",
+ meta_info=ProcessResultMetaInfo(routed_experts=routed_experts_str),
+ )
+
+ result = _run_generate(variant, generation_env)
+ assert result.requests == [expected_request(variant, return_routed_experts=True)]
+ sample = result.sample[0] if isinstance(result.sample, list) else result.sample
+ assert sample.rollout_routed_experts is not None
+ assert sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk)
+ np.testing.assert_array_equal(sample.rollout_routed_experts, routed_experts_array)
+
+
+class TestMetaInfo:
+ @pytest.mark.parametrize(
+ "generation_env", [{"process_fn_kwargs": {"cached_tokens": 3, "weight_version": "v1.0"}}], indirect=True
+ )
+ def test_meta_info_fields_updated(self, variant, generation_env):
+ result = _run_generate(variant, generation_env)
+ assert result.requests == [expected_request(variant)]
+ assert listify(result.sample) == [expected_sample(variant, cached_tokens=3, weight_versions=["v1.0"])]
+
+ @pytest.mark.parametrize(
+ "generation_env",
+ [
+ {
+ "args_kwargs": {"sglang_speculative_algorithm": "EAGLE"},
+ "process_fn_kwargs": {"spec_accept_token_num": 10, "spec_draft_token_num": 15, "spec_verify_ct": 3},
+ }
+ ],
+ indirect=True,
+ )
+ def test_spec_info_updated(self, variant, generation_env):
+ result = _run_generate(variant, generation_env)
+ assert result.requests == [expected_request(variant)]
+ assert listify(result.sample) == [
+ expected_sample(
+ variant,
+ spec_info=Sample.SpecInfo(
+ spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5
+ ),
+ )
+ ]
+
+
+class TestInputStatusValidation:
+ @pytest.mark.parametrize("status", [Sample.Status.PENDING, Sample.Status.ABORTED])
+ def test_allowed_statuses(self, variant, generation_env, status):
+ result = _run_generate(variant, generation_env, _make_sample(status=status))
+ assert result.requests == [expected_request(variant)]
+ assert listify(result.sample) == [expected_sample(variant)]
+
+ @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED])
+ def test_rejected_statuses(self, variant, generation_env, status):
+ if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"):
+ pytest.skip("not tested yet")
+ with pytest.raises(AssertionError):
+ _run_generate(variant, generation_env, _make_sample(status=status))
+
+
+class TestPayloadStructure:
+ def test_sampling_params_passed_through(self, variant, generation_env):
+ result = _run_generate(
+ variant, generation_env, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}
+ )
+ assert result.requests == [
+ expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9})
+ ]
+ assert listify(result.sample) == [expected_sample(variant)]
+
+
+class TestBoundaryConditions:
+ def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env):
+ if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"):
+ pytest.skip("not tested yet")
+ existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110))
+ sample = _make_sample(tokens=existing_tokens, response="x" * 10, response_length=10)
+
+ result = _run_generate(variant, generation_env, sample, {"max_new_tokens": 10, "temperature": 0.7})
+ assert result.requests == []
+ assert result.sample == expected_sample(
+ variant,
+ response="x" * 10,
+ response_length=10,
+ tokens=existing_tokens,
+ rollout_log_probs=None,
+ status=Sample.Status.TRUNCATED,
+ prompt_tokens=0,
+ )
+
+ @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"rollout_max_context_len": 5}}], indirect=True)
+ def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env):
+ if variant == "old_sglang_rollout":
+ pytest.skip("old_sglang_rollout does not support rollout_max_context_len")
+ if variant == "multi_turn_multi_samples":
+ pytest.skip("multi_turn_multi_samples returns empty list when first turn fails")
+ result = _run_generate(variant, generation_env)
+ assert result.requests == []
+ tokens = PROMPT_TOKENS if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") else []
+ assert listify(result.sample) == [
+ expected_sample(
+ variant,
+ response="",
+ response_length=0,
+ tokens=tokens,
+ rollout_log_probs=None,
+ status=Sample.Status.TRUNCATED,
+ prompt_tokens=0,
+ loss_mask=None if variant == "multi_turn_single_sample" else _UNSET,
+ )
+ ]
+
+ @pytest.mark.parametrize(
+ "generation_env,expected_max_new_tokens",
+ [
+ ({"args_kwargs": {"rollout_max_context_len": 10}}, 10 - PROMPT_TOKEN_LEN),
+ ({"args_kwargs": {"rollout_max_context_len": 8}}, 8 - PROMPT_TOKEN_LEN),
+ ({"args_kwargs": {"rollout_max_context_len": 100}}, DEFAULT_MAX_NEW_TOKENS),
+ ],
+ indirect=["generation_env"],
+ )
+ def test_moderate_length_input_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens):
+ if variant == "old_sglang_rollout":
+ pytest.skip("old_sglang_rollout does not support rollout_max_context_len")
+ result = _run_generate(variant, generation_env)
+ assert len(result.requests) == 1
+ assert result.requests[0]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens
+ assert result.requests[0]["sampling_params"]["temperature"] == SAMPLING_PARAMS["temperature"]
+ assert listify(result.sample) == [expected_sample(variant)]
+
+ @pytest.mark.parametrize(
+ "generation_env",
+ [{"args_kwargs": {"rollout_max_context_len": PROMPT_TOKEN_LEN}}],
+ indirect=True,
+ )
+ def test_adjusted_max_new_tokens_zero_returns_truncated(self, variant, generation_env):
+ if variant == "old_sglang_rollout":
+ pytest.skip("old_sglang_rollout does not support rollout_max_context_len")
+ if variant == "multi_turn_multi_samples":
+ pytest.skip("multi_turn_multi_samples returns empty list when first turn fails")
+ result = _run_generate(variant, generation_env)
+ assert result.requests == []
+ tokens = PROMPT_TOKENS if variant == "multi_turn_single_sample" else []
+ assert listify(result.sample) == [
+ expected_sample(
+ variant,
+ response="",
+ response_length=0,
+ tokens=tokens,
+ rollout_log_probs=None,
+ status=Sample.Status.TRUNCATED,
+ prompt_tokens=0,
+ loss_mask=None if variant == "multi_turn_single_sample" else _UNSET,
+ )
+ ]
+
+
+class TestEmptyResponse:
+ @pytest.mark.parametrize("generation_env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True)
+ def test_empty_response(self, variant, generation_env):
+ result = _run_generate(variant, generation_env)
+ assert result.requests == [expected_request(variant)]
+ assert listify(result.sample) == [
+ expected_sample(variant, response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[])
+ ]
+
+
+VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
+
+
+class TestMultimodal:
+ @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True)
+ def test_multimodal_inputs_processed(self, variant, generation_env):
+ if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"):
+ pytest.skip("not tested yet")
+ test_image = Image.new("RGB", (64, 64), color="red")
+ multimodal_inputs = {"images": [test_image]}
+ processor = AutoProcessor.from_pretrained(VLM_MODEL_NAME, trust_remote_code=True)
+ expected_mti = {
+ k: v
+ for k, v in processor(text=PROMPT, **multimodal_inputs).items()
+ if k not in ["input_ids", "attention_mask"]
+ }
+
+ result = _run_generate(variant, generation_env, _make_sample(multimodal_inputs=multimodal_inputs))
+
+ assert result.requests == [
+ expected_request(
+ variant,
+ input_ids=PROMPT_TOKENS,
+ image_data=[encode_image_for_rollout_engine(test_image)],
+ )
+ ]
+ actual_mti = result.sample.multimodal_train_inputs
+ assert actual_mti is not None
+ assert set(actual_mti.keys()) == set(expected_mti.keys())
+ assert torch.all(actual_mti["pixel_values"] == expected_mti["pixel_values"])
+ assert torch.all(actual_mti["image_grid_thw"] == expected_mti["image_grid_thw"])
+ assert result.sample == expected_sample(
+ variant,
+ tokens=PROMPT_TOKENS + RESPONSE_TOKENS,
+ multimodal_inputs=multimodal_inputs,
+ multimodal_train_inputs=actual_mti,
+ )
diff --git a/tests/fast/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py
new file mode 100644
index 000000000..0f2305e75
--- /dev/null
+++ b/tests/fast/rollout/generate_hub/test_tool_call_utils.py
@@ -0,0 +1,99 @@
+import pytest
+
+from miles.rollout.generate_utils.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses
+
+TOOL_CALL_TEST_MODELS = [
+ "Qwen/Qwen2.5-0.5B-Instruct",
+ "Qwen/Qwen3-0.6B",
+ "Qwen/Qwen3-4B-Instruct-2507",
+ "Qwen/Qwen3-Coder-30B-A3B-Instruct",
+ # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI
+ "mistralai/Mistral-7B-Instruct-v0.3",
+ "deepseek-ai/DeepSeek-V3",
+ "stepfun-ai/step3",
+ "MiniMaxAI/MiniMax-M2",
+ "internlm/internlm3-8b-instruct",
+ "THUDM/glm-4-9b-chat",
+ "moonshotai/Kimi-K2-Instruct",
+ "XiaomiMiMo/MiMo-7B-RL",
+]
+
+SINGLE_TOOL_CALL_ONLY_MODELS = [
+ # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo
+]
+
+# Models where tokenize->decode produces extra whitespace vs direct string diff
+TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS = [
+ "THUDM/glm-4-9b-chat",
+]
+
+SAMPLE_TOOL_RESPONSES = [
+ {
+ "role": "tool",
+ "tool_call_id": "call00000",
+ "content": '{"year": 2026}',
+ "name": "get_year",
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call00001",
+ "content": '{"temperature": 25}',
+ "name": "get_temperature",
+ },
+]
+
+
+class TestTokenizeToolResponses:
+ @pytest.mark.parametrize("model_name", ["Qwen/Qwen3-0.6B"])
+ def test_snapshot(self, model_name):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ token_ids = tokenize_tool_responses(SAMPLE_TOOL_RESPONSES, tokenizer)
+ decoded = tokenizer.decode(token_ids)
+
+ assert decoded == (
+ "<|im_start|>user\n"
+ "\n"
+ '{"year": 2026}\n'
+ "\n"
+ "\n"
+ '{"temperature": 25}\n'
+ "<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+
+ @pytest.mark.parametrize("num_tools", [1, 2])
+ @pytest.mark.parametrize("model_name", TOOL_CALL_TEST_MODELS)
+ def test_tokenize_tool_responses(self, model_name, num_tools):
+ if num_tools > 1 and model_name in SINGLE_TOOL_CALL_ONLY_MODELS:
+ pytest.skip(f"{model_name} only supports single tool call")
+
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ tool_responses = SAMPLE_TOOL_RESPONSES[:num_tools]
+ assert len(tool_responses) == num_tools
+
+ actual_token_ids = tokenize_tool_responses(tool_responses, tokenizer)
+ actual_str = tokenizer.decode(actual_token_ids)
+
+ dummy_assistant = _build_dummy_assistant(tool_responses)
+ base_messages = [_DUMMY_USER, dummy_assistant]
+ expected_str = self._compute_chat_template_diff(base_messages, tool_responses, tokenizer)
+
+ if model_name in TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS:
+ # Some models produce whitespace differences between tokenize->decode and direct string diff
+ actual_str = actual_str.replace(" ", "")
+ expected_str = expected_str.replace(" ", "")
+
+ assert actual_str == expected_str, f"{model_name=}"
+
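+ # Render the chat template with and without the tool messages and keep the suffix: exactly the text contributed by the tool responses.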
+ @staticmethod
+ def _compute_chat_template_diff(base_messages, extra_messages, tokenizer) -> str:
+ text_with = tokenizer.apply_chat_template(
+ base_messages + extra_messages, tokenize=False, add_generation_prompt=True
+ )
+ text_without = tokenizer.apply_chat_template(base_messages, tokenize=False, add_generation_prompt=False)
+ return text_with[len(text_without) :]
diff --git a/tests/fast/rollout/generate_utils/__init__.py b/tests/fast/rollout/generate_utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/rollout/generate_utils/test_sample_utils.py b/tests/fast/rollout/generate_utils/test_sample_utils.py
new file mode 100644
index 000000000..c53fbbb56
--- /dev/null
+++ b/tests/fast/rollout/generate_utils/test_sample_utils.py
@@ -0,0 +1,156 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from miles.rollout.generate_utils.sample_utils import _merge_sample_pair
+from miles.utils.types import Sample
+
+
+@pytest.fixture
+def mock_tokenizer():
+ tokenizer = MagicMock()
+ tokenizer.decode = lambda tokens: f"<decoded:{tokens}>"
+ return tokenizer
+
+
+def make_sample(
+ prompt="test_prompt",
+ tokens=None,
+ response="",
+ response_length=0,
+ loss_mask=None,
+ rollout_log_probs=None,
+ status=Sample.Status.COMPLETED,
+ label="test_label",
+ reward=1.0,
+ index=0,
+ group_index=0,
+):
+ return Sample(
+ prompt=prompt,
+ tokens=tokens or [],
+ response=response,
+ response_length=response_length,
+ loss_mask=loss_mask,
+ rollout_log_probs=rollout_log_probs,
+ status=status,
+ label=label,
+ reward=reward,
+ index=index,
+ group_index=group_index,
+ )
+
+
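+# _merge_sample_pair treats the tokens b adds before its own response (the observation gap) as non-trained: loss_mask 0 and rollout_log_prob 0.0 in the merged sample.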
+class TestMergeSamples:
+ def test_basic_merge(self, mock_tokenizer):
+ a = make_sample(
+ tokens=[1, 2, 3, 10, 11, 12],
+ response="response1",
+ response_length=3,
+ loss_mask=[1, 1, 1],
+ rollout_log_probs=[-0.1, -0.2, -0.3],
+ )
+ b = make_sample(
+ tokens=[1, 2, 3, 10, 11, 12, 20, 21, 30, 31, 32],
+ response="response2",
+ response_length=3,
+ loss_mask=[1, 1, 1],
+ rollout_log_probs=[-0.4, -0.5, -0.6],
+ status=Sample.Status.TRUNCATED,
+ )
+
+ merged = _merge_sample_pair(a, b, mock_tokenizer)
+
+ assert merged.tokens == b.tokens
+ assert merged.response_length == 3 + 2 + 3
+ assert merged.loss_mask == [1, 1, 1, 0, 0, 1, 1, 1]
+ assert merged.rollout_log_probs == [-0.1, -0.2, -0.3, 0.0, 0.0, -0.4, -0.5, -0.6]
+ assert merged.prompt == a.prompt
+ assert merged.status == b.status
+ assert merged.label == a.label
+ assert merged.index == a.index
+ assert merged.group_index == a.group_index
+ assert "response1" in merged.response
+ assert "response2" in merged.response
+ assert "" in merged.response
+
+ def test_loss_mask_none_defaults_to_all_ones(self, mock_tokenizer):
+ a = make_sample(
+ tokens=[1, 2, 10],
+ response_length=1,
+ loss_mask=None,
+ rollout_log_probs=None,
+ )
+ b = make_sample(
+ tokens=[1, 2, 10, 20, 30],
+ response_length=1,
+ loss_mask=None,
+ rollout_log_probs=None,
+ )
+
+ merged = _merge_sample_pair(a, b, mock_tokenizer)
+
+ assert merged.loss_mask == [1, 0, 1]
+ assert merged.rollout_log_probs == [0.0, 0.0, 0.0]
+
+ def test_tokens_prefix_mismatch_raises(self, mock_tokenizer):
+ a = make_sample(
+ tokens=[1, 2, 3],
+ response_length=1,
+ loss_mask=[1],
+ )
+ b = make_sample(
+ tokens=[1, 2, 99, 20, 30],
+ response_length=1,
+ loss_mask=[1],
+ )
+
+ with pytest.raises(AssertionError, match="b.tokens must start with a.tokens"):
+ _merge_sample_pair(a, b, mock_tokenizer)
+
+ def test_field_mismatch_raises(self, mock_tokenizer):
+ a = make_sample(
+ tokens=[1, 2, 10],
+ response_length=1,
+ loss_mask=[1],
+ index=0,
+ )
+ b = make_sample(
+ tokens=[1, 2, 10, 20, 30],
+ response_length=1,
+ loss_mask=[1],
+ index=1,
+ )
+
+ with pytest.raises(AssertionError, match="index mismatch"):
+ _merge_sample_pair(a, b, mock_tokenizer)
+
+ def test_obs_len_invalid_raises(self, mock_tokenizer):
+ a = make_sample(
+ tokens=[1, 2, 10],
+ response_length=1,
+ loss_mask=[1],
+ )
+ b = make_sample(
+ tokens=[1, 2, 10, 30],
+ response_length=1,
+ loss_mask=[1],
+ )
+
+ with pytest.raises(AssertionError, match="obs_len must be > 0"):
+ _merge_sample_pair(a, b, mock_tokenizer)
+
+ def test_sample_validate_fails_raises(self, mock_tokenizer):
+ a = make_sample(
+ tokens=[1, 2, 10, 11],
+ response_length=2,
+ loss_mask=[1],
+ )
+ b = make_sample(
+ tokens=[1, 2, 10, 11, 20, 30],
+ response_length=1,
+ loss_mask=[1],
+ )
+
+ with pytest.raises(AssertionError, match="loss_mask length"):
+ _merge_sample_pair(a, b, mock_tokenizer)
diff --git a/tests/fast/rollout/inference_rollout/__init__.py b/tests/fast/rollout/inference_rollout/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/rollout/inference_rollout/conftest.py b/tests/fast/rollout/inference_rollout/conftest.py
new file mode 100644
index 000000000..ca47edeeb
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/conftest.py
@@ -0,0 +1,45 @@
+from unittest.mock import patch
+
+import pytest
+
+from miles.utils.arguments import parse_args
+
+
+def _build_mock_args(extra_argv: list[str] | None = None):
+ argv = [
+ "pytest",
+ "--train-backend",
+ "fsdp",
+ "--rollout-batch-size",
+ "2",
+ "--n-samples-per-prompt",
+ "1",
+ "--num-rollout",
+ "1",
+ "--rollout-num-gpus",
+ "4",
+ "--rollout-num-gpus-per-engine",
+ "2",
+ "--hf-checkpoint",
+ "Qwen/Qwen3-0.6B",
+ "--prompt-data",
+ "/dev/null",
+ "--input-key",
+ "input",
+ "--label-key",
+ "label",
+ "--rm-type",
+ "math",
+ "--use-miles-router",
+ "--sglang-router-ip",
+ "127.0.0.1",
+ "--sglang-router-port",
+ "30000",
+ ] + (extra_argv or [])
+ with patch("sys.argv", argv):
+ return parse_args()
+
+
+@pytest.fixture
+def mock_args():
+ return _build_mock_args()
diff --git a/tests/fast/rollout/inference_rollout/integration/__init__.py b/tests/fast/rollout/inference_rollout/integration/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/rollout/inference_rollout/integration/test_basic.py b/tests/fast/rollout/inference_rollout/integration/test_basic.py
new file mode 100644
index 000000000..5b791829d
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/test_basic.py
@@ -0,0 +1,69 @@
+import pytest
+from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant
+from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig
+from tests.fast.rollout.inference_rollout.integration.utils import (
+ MODULAR_ROLLOUT_BASE_ARGV,
+ expected_sample,
+ load_and_call_train,
+)
+
+from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput
+from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function
+
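+# Cover the legacy rollout with the legacy generate function, the new rollout driving the legacy generate function, and the fully refactored path.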
+_VARIANTS = [
+ pytest.param(
+ RolloutEnvConfig(
+ extra_argv=[
+ "--rollout-function-path",
+ "miles.rollout.sglang_rollout.generate_rollout",
+ "--eval-function-path",
+ "miles.rollout.sglang_rollout.generate_rollout",
+ "--custom-generate-function-path",
+ "miles.rollout.sglang_rollout.generate",
+ ]
+ ),
+ id="old_rollout_old_generate",
+ ),
+ pytest.param(
+ RolloutEnvConfig(
+ extra_argv=[
+ "--rollout-function-path",
+ "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn",
+ "--custom-generate-function-path",
+ "miles.rollout.sglang_rollout.generate",
+ ]
+ ),
+ id="new_rollout_old_generate",
+ ),
+ pytest.param(
+ RolloutEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant("single_turn")),
+ id="new_rollout_new_generate",
+ ),
+]
+
+
+@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True)
+def test_train(rollout_env):
+ env = rollout_env
+ out = load_and_call_train(env.args, env.data_source)
+
+ assert len(out.samples) == env.args.rollout_batch_size
+ group = out.samples[0]
+ assert len(group) == env.args.n_samples_per_prompt
+ assert group[0] == expected_sample(group_index=0)
+
+
+@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True)
+def test_eval(rollout_env):
+ env = rollout_env
+ fn = load_rollout_function(
+ RolloutFnConstructorInput(args=env.args, data_source=env.data_source), env.args.eval_function_path
+ )
+ out = call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0))
+
+ assert "toy" in out.data
+ rewards = out.data["toy"]["rewards"]
+ samples = out.data["toy"]["samples"]
+ assert len(rewards) == len(samples) == env.args.n_samples_per_eval_prompt
+ assert rewards[0] == 1
+ assert samples[0] == expected_sample(group_index=None)
diff --git a/tests/fast/rollout/inference_rollout/integration/test_deterministic.py b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py
new file mode 100644
index 000000000..69a235911
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py
@@ -0,0 +1,37 @@
+import pytest
+
+from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train
+
+
+@pytest.mark.parametrize(
+ "rollout_env,expected_seeds",
+ [
+ pytest.param(
+ integration_env_config(
+ [
+ "--sglang-enable-deterministic-inference",
+ "--rollout-seed",
+ "42",
+ "--n-samples-per-prompt",
+ "3",
+ "--rollout-batch-size",
+ "1",
+ ]
+ ),
+ {42, 43, 44},
+ id="enabled",
+ ),
+ pytest.param(
+ integration_env_config(["--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]),
+ {None},
+ id="disabled",
+ ),
+ ],
+ indirect=["rollout_env"],
+)
+def test_sampling_seeds(rollout_env, expected_seeds):
+ env = rollout_env
+ load_and_call_train(env.args, env.data_source)
+
+ seeds = {req.get("sampling_params", {}).get("sampling_seed") for req in env.mock_server.request_log}
+ assert seeds == expected_seeds
diff --git a/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py
new file mode 100644
index 000000000..0ca5743ac
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py
@@ -0,0 +1,46 @@
+from contextlib import nullcontext
+
+import pytest
+from tests.fast.rollout.inference_rollout.integration.utils import (
+ MIXED_DATA_ROWS,
+ filter_by_reward,
+ integration_env_config,
+ load_and_call_train,
+)
+
+from miles.utils.misc import function_registry
+
+
+@pytest.mark.parametrize(
+ "rollout_env,use_filter,expect_all_correct",
+ [
+ pytest.param(
+ integration_env_config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS),
+ False,
+ False,
+ id="no_filter",
+ ),
+ pytest.param(
+ integration_env_config(
+ ["--rollout-batch-size", "3", "--dynamic-sampling-filter-path", "test:filter_by_reward"],
+ data_rows=MIXED_DATA_ROWS,
+ ),
+ True,
+ True,
+ id="with_filter",
+ ),
+ ],
+ indirect=["rollout_env"],
+)
+def test_filter_effect(rollout_env, use_filter, expect_all_correct):
+ env = rollout_env
+ ctx = function_registry.temporary("test:filter_by_reward", filter_by_reward) if use_filter else nullcontext()
+
+ with ctx:
+ out = load_and_call_train(env.args, env.data_source)
+
+ rewards = {group[0].reward for group in out.samples}
+ if expect_all_correct:
+ assert rewards == {1}, "Filter should keep only correct samples"
+ else:
+ assert 0 in rewards, "Without filter, incorrect samples should be present"
diff --git a/tests/fast/rollout/inference_rollout/integration/test_group_rm.py b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py
new file mode 100644
index 000000000..afd870c30
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py
@@ -0,0 +1,22 @@
+import pytest
+
+from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train
+
+
+@pytest.mark.parametrize(
+ "rollout_env",
+ [
+ pytest.param(
+ integration_env_config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]),
+ id="group_rm_enabled",
+ ),
+ ],
+ indirect=True,
+)
+def test_group_rm_rewards_set(rollout_env):
+ env = rollout_env
+ out = load_and_call_train(env.args, env.data_source)
+
+ assert len(out.samples) == env.args.rollout_batch_size
+ rewards = [sample.reward for group in out.samples for sample in group]
+ assert all(r in (0, 1) for r in rewards)
diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py
new file mode 100644
index 000000000..2b12d3d88
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py
@@ -0,0 +1,65 @@
+import pytest
+from tests.fast.fixtures.rollout_fixtures import DEFAULT_DATA_ROWS, RolloutEnvConfig
+from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_train
+
+from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput
+from miles.utils.misc import function_registry
+from miles.utils.types import Sample
+
+
+async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput:
+ sample = input.sample
+ s1 = Sample(
+ prompt=sample.prompt,
+ response="\\boxed{8}",
+ response_length=5,
+ tokens=sample.tokens + [59, 79075, 90, 23, 92],
+ label=sample.label,
+ reward=None,
+ status=Sample.Status.COMPLETED,
+ )
+ s2 = Sample(
+ prompt=sample.prompt,
+ response="\\boxed{8}",
+ response_length=5,
+ tokens=sample.tokens + [59, 79075, 90, 23, 92],
+ label=sample.label,
+ reward=0.5,
+ status=Sample.Status.COMPLETED,
+ )
+ return GenerateFnOutput(samples=[s1, s2])
+
+
+@pytest.mark.parametrize(
+ "rollout_env",
+ [
+ pytest.param(
+ RolloutEnvConfig(
+ extra_argv=MODULAR_ROLLOUT_BASE_ARGV
+ + [
+ "--custom-generate-function-path",
+ "test:multi_sample_generate",
+ "--rollout-batch-size",
+ "1",
+ "--n-samples-per-prompt",
+ "1",
+ ],
+ data_rows=DEFAULT_DATA_ROWS,
+ ),
+ id="multi_sample_output",
+ ),
+ ],
+ indirect=True,
+)
+def test_multi_sample_output_preserves_existing_reward(rollout_env):
+ env = rollout_env
+ with function_registry.temporary("test:multi_sample_generate", _multi_sample_generate):
+ out = load_and_call_train(env.args, env.data_source)
+
+ assert len(out.samples) == env.args.rollout_batch_size
+ group = out.samples[0]
+ assert isinstance(group[0], list)
+ samples = group[0]
+ assert len(samples) == 2
+ assert samples[0].reward == 1
+ assert samples[1].reward == 0.5
diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py
new file mode 100644
index 000000000..c41d71399
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py
@@ -0,0 +1,114 @@
+from typing import Any
+
+import pytest
+from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant
+from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig
+from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_rollout
+
+from miles.utils.test_utils.mock_tools import TwoTurnStub
+from miles.utils.types import Sample
+
+
+TWO_TURN_DATA_ROWS = [{"input": [{"role": "user", "content": TwoTurnStub.USER_QUESTION}], "label": "2008"}]
+
+_VARIANT_NAMES = [
+ "multi_turn_single_sample",
+ "multi_turn_multi_samples",
+ "agentic_tool_call_single_sample",
+ "agentic_tool_call_multi_samples",
+]
+
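+# Shared argv for all variants: two prompts with two samples each, scored by the
+# module-local `_simple_reward_function` defined below.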
+_BASE_EXTRA_ARGV = [
+ "--rollout-batch-size",
+ "2",
+ "--n-samples-per-prompt",
+ "2",
+ "--n-samples-per-eval-prompt",
+ "2",
+ "--custom-rm-path",
+ "tests.fast.rollout.inference_rollout.integration.test_multi_turn._simple_reward_function",
+]
+
+
+def _config_for_variant(variant: str) -> RolloutEnvConfig:
+ return RolloutEnvConfig(
+ extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + _BASE_EXTRA_ARGV,
+ data_rows=TWO_TURN_DATA_ROWS,
+ )
+
+
+@pytest.mark.parametrize(
+ "variant,rollout_env",
+ [pytest.param(variant, _config_for_variant(variant), id=variant) for variant in _VARIANT_NAMES],
+ indirect=["rollout_env"],
+)
+@pytest.mark.parametrize("test_type", ["train", "eval"])
+def test_rollout(rollout_env, variant, test_type):
+ env = rollout_env
+ env.mock_server.process_fn = TwoTurnStub.process_fn
+
+ out = load_and_call_rollout(env.args, env.data_source, mode=test_type)
+
+ if test_type == "train":
+ assert len(out.samples) == env.args.rollout_batch_size
+ group = out.samples[0]
+ _verify_samples(variant, group)
+ else:
+ assert "toy" in out.data
+ samples = out.data["toy"]["samples"]
+ _verify_samples(variant, samples)
+
+
+def _verify_samples(variant: str, samples: list[Any]):
+ is_multi_samples = variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples")
+
+ if is_multi_samples:
+ if len(samples) > 0 and isinstance(samples[0], list):
+ # Train mode: list[list[Sample]], grouped by prompt
+ assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}"
+ for group_sample in samples:
+ assert isinstance(group_sample, list), "multi_samples variant should return list[Sample] per generate"
+ _verify_group_samples(group_sample)
+ else:
+ # Eval mode: list[Sample], flattened
+ # n_samples_per_eval_prompt=2, and each generate returns 2 turns, so 2*2=4 samples
+ assert (
+ len(samples) == 4
+ ), f"n_samples_per_eval_prompt=2, each generate returns 2 turns, so should have 4 samples, got {len(samples)}"
+ # Group samples by prompt (every 2 samples form a group)
+ group_samples_list = [samples[i : i + 2] for i in range(0, len(samples), 2)]
+ for group_samples in group_samples_list:
+ _verify_group_samples(group_samples)
+ else:
+ assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}"
+ for sample in samples:
+ assert isinstance(sample, Sample), "single_sample variant should return Sample, not list"
+ _verify_sample(sample)
+
+
+def _verify_group_samples(group_samples: list[Sample], expected_count: int = 2):
+ assert len(group_samples) == expected_count, f"Group should have {expected_count} samples (one per turn)"
+ for i, sample in enumerate(group_samples):
+ _verify_sample(sample, expect_answer=(i == len(group_samples) - 1))
+
+
+def _verify_sample(sample: Sample, expected_reward: float = 1.0, expect_answer: bool = True):
+ assert sample.status == Sample.Status.COMPLETED
+ assert sample.reward == expected_reward, f"Sample should have reward={expected_reward}"
+ if expect_answer:
+ assert "2008" in sample.response, "Response should contain final answer '2008'"
+
+
+async def _simple_reward_function(args, samples: Sample | list[Sample]) -> float | list[float]:
+ if isinstance(samples, list):
+ # For multi_samples variants, use the last sample's reward
+ if getattr(args, "generate_multi_samples", False):
+ return [_check_reward(samples[-1])] * len(samples)
+ else:
+ return [_check_reward(sample) for sample in samples]
+ else:
+ return _check_reward(samples)
+
+
+def _check_reward(sample: Sample) -> float:
+ return float(sample.response and (str(sample.label) in sample.response))
diff --git a/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py
new file mode 100644
index 000000000..0812962cc
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py
@@ -0,0 +1,48 @@
+import pytest
+from tests.fast.rollout.inference_rollout.integration.utils import (
+ filter_by_reward,
+ integration_env_config,
+ load_and_call_train,
+)
+
+from miles.utils.misc import function_registry
+
+_DATA_ROWS = [
+ {"input": "What is 1+7?", "label": "8"},
+ {"input": "What is 1+8?", "label": "wrong"},
+ {"input": "What is 1+9?", "label": "wrong"},
+ {"input": "What is 1+6?", "label": "wrong"},
+]
+
+_BASE_ARGV = [
+ "--over-sampling-batch-size",
+ "4",
+ "--dynamic-sampling-filter-path",
+ "test:filter_by_reward",
+]
+
+
+def _over_sampling_config(rollout_batch_size: int):
+ return integration_env_config(["--rollout-batch-size", str(rollout_batch_size)] + _BASE_ARGV, data_rows=_DATA_ROWS)
+
+
+@pytest.mark.parametrize(
+ "rollout_env,expected_rounds",
+ [
+ pytest.param(_over_sampling_config(1), 1, id="one_round"),
+ pytest.param(_over_sampling_config(2), 2, id="two_rounds"),
+ ],
+ indirect=["rollout_env"],
+)
+def test_over_sampling_rounds(rollout_env, expected_rounds):
+ env = rollout_env
+
+ with function_registry.temporary("test:filter_by_reward", filter_by_reward):
+ out = load_and_call_train(env.args, env.data_source)
+
+ assert len(out.samples) == env.args.rollout_batch_size
+ assert all(group[0].reward == 1 for group in out.samples)
+
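+ # Each over-sampling round issues one generate request per prompt in the
+ # over-sampling batch, so the mock server's request log reveals how many rounds ran.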
+ requests_count = len(env.mock_server.request_log)
+ expected_requests = expected_rounds * env.args.over_sampling_batch_size
+ assert requests_count == expected_requests, f"Expected {expected_rounds} round(s) = {expected_requests} requests"
diff --git a/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py
new file mode 100644
index 000000000..36e78c16c
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py
@@ -0,0 +1,67 @@
+from unittest.mock import Mock
+
+import pytest
+from tests.fast.rollout.inference_rollout.integration.utils import (
+ filter_by_reward,
+ integration_env_config,
+ load_and_call_train,
+)
+
+from miles.utils.misc import function_registry
+
+# Data with only 2 reward=1 samples out of 4.
+# This ensures all 4 samples must be generated to collect 2 valid ones.
+_FILTER_TEST_DATA_ROWS = [
+ {"input": "What is 1+7?", "label": "8"}, # reward=1
+ {"input": "What is 1+8?", "label": "wrong"}, # reward=0
+ {"input": "What is 1+9?", "label": "wrong"}, # reward=0
+ {"input": "What is 1+6?", "label": "7"}, # reward=1
+]
+
+
+@pytest.mark.parametrize(
+ "rollout_env",
+ [
+ pytest.param(
+ integration_env_config(
+ [
+ "--rollout-batch-size",
+ "2",
+ "--over-sampling-batch-size",
+ "4",
+ "--dynamic-sampling-filter-path",
+ "test:filter_by_reward",
+ "--rollout-sample-filter-path",
+ "test:sample_filter",
+ "--rollout-all-samples-process-path",
+ "test:all_samples_process",
+ ],
+ data_rows=_FILTER_TEST_DATA_ROWS,
+ ),
+ id="sample_filter_vs_all_samples",
+ ),
+ ],
+ indirect=True,
+)
+def test_sample_filter_and_all_samples_process(rollout_env):
+ env = rollout_env
+ sample_filter_mock = Mock()
+ all_samples_process_mock = Mock()
+
+ with (
+ function_registry.temporary("test:filter_by_reward", filter_by_reward),
+ function_registry.temporary("test:sample_filter", sample_filter_mock),
+ function_registry.temporary("test:all_samples_process", all_samples_process_mock),
+ ):
+ load_and_call_train(env.args, env.data_source)
+
+ sample_filter_mock.assert_called_once()
+ _, filtered_data = sample_filter_mock.call_args[0]
+ rewards = [g[0][0].reward if isinstance(g[0], list) else g[0].reward for g in filtered_data]
+ assert all(r == 1 for r in rewards)
+
+ all_samples_process_mock.assert_called_once()
+ _, all_samples, data_source = all_samples_process_mock.call_args[0]
+ assert data_source is not None
+
+ assert len(all_samples) > len(filtered_data), "all_samples_process should see more samples than sample_filter"
diff --git a/tests/fast/rollout/inference_rollout/integration/test_semaphore.py b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py
new file mode 100644
index 000000000..889a9ff8a
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py
@@ -0,0 +1,33 @@
+import pytest
+
+from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train
+
+_DATA_ROWS = [{"input": f"What is 1+{i}?", "label": str(1 + i)} for i in range(10)]
+_BASE_ARGV = ["--rollout-batch-size", "4", "--n-samples-per-prompt", "2"]
+
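+# With a concurrency limit of 1 the mock server should never observe overlapping requests;
+# with an effectively unlimited setting, at least two requests should be in flight at once.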
+
+@pytest.mark.parametrize(
+ "rollout_env,expected_range",
+ [
+ pytest.param(
+ integration_env_config(
+ ["--sglang-server-concurrency", "1"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05
+ ),
+ (1, 1),
+ id="limit_1",
+ ),
+ pytest.param(
+ integration_env_config(
+ ["--sglang-server-concurrency", "999"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05
+ ),
+ (2, 999),
+ id="no_limit",
+ ),
+ ],
+ indirect=["rollout_env"],
+)
+def test_max_concurrent(rollout_env, expected_range):
+ env = rollout_env
+ load_and_call_train(env.args, env.data_source)
+ min_expected, max_expected = expected_range
+ assert min_expected <= env.mock_server.max_concurrent <= max_expected
diff --git a/tests/fast/rollout/inference_rollout/integration/utils.py b/tests/fast/rollout/inference_rollout/integration/utils.py
new file mode 100644
index 000000000..ad413cf94
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/integration/utils.py
@@ -0,0 +1,89 @@
+from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant
+from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig
+
+from miles.rollout.base_types import (
+ RolloutFnConstructorInput,
+ RolloutFnEvalInput,
+ RolloutFnOutput,
+ RolloutFnTrainInput,
+)
+from miles.rollout.filter_hub.base_types import DynamicFilterOutput
+from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function
+from miles.utils.types import Sample
+
+
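+# The exact Sample the mocked rollout environment is expected to produce for the default
+# "What is 1+7?" prompt; integration tests compare against it directly.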
+def expected_sample(*, group_index: int | None) -> Sample:
+ return Sample(
+ group_index=group_index,
+ index=0,
+ prompt="What is 1+7?",
+ tokens=[3838, 374, 220, 16, 10, 22, 30, 59, 79075, 90, 23, 92],
+ multimodal_inputs=None,
+ multimodal_train_inputs=None,
+ response="\\boxed{8}",
+ response_length=5,
+ label="8",
+ reward=1,
+ loss_mask=None,
+ weight_versions=[],
+ rollout_log_probs=[-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125],
+ rollout_routed_experts=None,
+ remove_sample=False,
+ status=Sample.Status.COMPLETED,
+ metadata={},
+ train_metadata=None,
+ non_generation_time=0.0,
+ spec_info=Sample.SpecInfo(
+ spec_accept_token_num=0, spec_draft_token_num=0, spec_verify_ct=0, completion_token_num=0
+ ),
+ prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=0, total_prompt_tokens=7),
+ )
+
+
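+# Argv that selects the modular InferenceRolloutFn rollout implementation; variant-specific
+# generation flags are layered on top via `extra_argv_for_variant`.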
+MODULAR_ROLLOUT_BASE_ARGV = [
+ "--rollout-function-path",
+ "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn",
+]
+
+MIXED_DATA_ROWS = [
+ {"input": "What is 1+7?", "label": "8"},
+ {"input": "What is 1+8?", "label": "9"},
+ {"input": "What is 1+9?", "label": "wrong"},
+ {"input": "What is 1+6?", "label": "7"},
+]
+
+
+def integration_env_config(
+ extra_argv: list[str],
+ data_rows: list[dict] | None = None,
+ latency: float = 0.0,
+ variant: str = "single_turn",
+):
+ return RolloutEnvConfig(
+ extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + extra_argv,
+ data_rows=data_rows,
+ latency=latency,
+ )
+
+
+def load_and_call_rollout(args, data_source, mode: str = "train") -> RolloutFnOutput:
+ function_path = args.rollout_function_path if mode == "train" else args.eval_function_path
+ fn = load_rollout_function(
+ RolloutFnConstructorInput(args=args, data_source=data_source),
+ function_path,
+ )
+ if mode == "train":
+ return call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0))
+ else:
+ return call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0))
+
+
+def load_and_call_train(args, data_source):
+ return load_and_call_rollout(args, data_source, mode="train")
+
+
+def filter_by_reward(args, samples, **kwargs):
+ reward = samples[0].reward if not isinstance(samples[0], list) else samples[0][0].reward
+ if reward == 1:
+ return DynamicFilterOutput(keep=True)
+ return DynamicFilterOutput(keep=False, reason="reward_zero")
diff --git a/tests/fast/rollout/inference_rollout/test_compatibility.py b/tests/fast/rollout/inference_rollout/test_compatibility.py
new file mode 100644
index 000000000..ddfecd067
--- /dev/null
+++ b/tests/fast/rollout/inference_rollout/test_compatibility.py
@@ -0,0 +1,196 @@
+import asyncio
+from unittest.mock import MagicMock
+
+import pytest
+
+from miles.rollout.base_types import (
+ GenerateFnInput,
+ GenerateFnOutput,
+ RolloutFnConstructorInput,
+ RolloutFnEvalInput,
+ RolloutFnEvalOutput,
+ RolloutFnTrainInput,
+ RolloutFnTrainOutput,
+)
+from miles.rollout.inference_rollout.compatibility import (
+ LegacyGenerateFnAdapter,
+ LegacyRolloutFnAdapter,
+ call_rollout_function,
+ load_generate_function,
+ load_rollout_function,
+)
+from miles.utils.async_utils import run
+from miles.utils.misc import function_registry
+
+
+@pytest.fixture
+def constructor_input():
+ return RolloutFnConstructorInput(args="dummy_args", data_source="dummy_data_source")
+
+
+@pytest.fixture
+def make_generate_fn_input():
+ def _make(evaluation: bool = False):
+ state = MagicMock()
+ state.args = MagicMock()
+
+ return GenerateFnInput(
+ state=state,
+ sample={"text": "test prompt"},
+ sampling_params={"temperature": 0.7},
+ evaluation=evaluation,
+ )
+
+ return _make
+
+
+class TestSupportedRolloutFormats:
+ """
+ Documentation tests showing the rollout function formats accepted by load_rollout_function.
+ """
+
+ @pytest.mark.parametrize("evaluation", [False, True])
+ def test_format_1_legacy_function_raw_output(self, constructor_input, evaluation):
+ def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False):
+ if evaluation:
+ return {"metric": {"accuracy": 0.9}}
+ return [[{"text": "sample"}]]
+
+ with function_registry.temporary("test:legacy_rollout", legacy_rollout_fn):
+ fn = load_rollout_function(constructor_input, "test:legacy_rollout")
+
+ input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput
+ result = call_rollout_function(fn, input_cls(rollout_id=1))
+
+ assert isinstance(fn, LegacyRolloutFnAdapter)
+ if evaluation:
+ assert isinstance(result, RolloutFnEvalOutput)
+ assert result.data == {"metric": {"accuracy": 0.9}}
+ else:
+ assert isinstance(result, RolloutFnTrainOutput)
+ assert result.samples == [[{"text": "sample"}]]
+
+ @pytest.mark.parametrize("evaluation", [False, True])
+ def test_format_2_legacy_function_typed_output(self, constructor_input, evaluation):
+ def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False):
+ if evaluation:
+ return RolloutFnEvalOutput(data={"ds": {"acc": 0.95}})
+ return RolloutFnTrainOutput(samples=[[{"text": "typed"}]])
+
+ with function_registry.temporary("test:legacy_typed", legacy_rollout_fn):
+ fn = load_rollout_function(constructor_input, "test:legacy_typed")
+
+ input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput
+ result = call_rollout_function(fn, input_cls(rollout_id=1))
+
+ if evaluation:
+ assert isinstance(result, RolloutFnEvalOutput)
+ assert result.data == {"ds": {"acc": 0.95}}
+ else:
+ assert isinstance(result, RolloutFnTrainOutput)
+ assert result.samples == [[{"text": "typed"}]]
+
+ @pytest.mark.parametrize("evaluation", [False, True])
+ def test_format_3_sync_class(self, constructor_input, evaluation):
+ class SyncRolloutFn:
+ def __init__(self, input: RolloutFnConstructorInput):
+ pass
+
+ def __call__(self, input):
+ if input.evaluation:
+ return RolloutFnEvalOutput(data={"test": {"score": 1}})
+ return RolloutFnTrainOutput(samples=[[{"text": "sync"}]])
+
+ with function_registry.temporary("test:sync_class", SyncRolloutFn):
+ fn = load_rollout_function(constructor_input, "test:sync_class")
+
+ input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput
+ result = call_rollout_function(fn, input_cls(rollout_id=1))
+
+ assert isinstance(fn, SyncRolloutFn)
+ expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput
+ assert isinstance(result, expected_type)
+
+ @pytest.mark.parametrize("evaluation", [False, True])
+ def test_format_4_async_class(self, constructor_input, evaluation):
+ class AsyncRolloutFn:
+ def __init__(self, input: RolloutFnConstructorInput):
+ pass
+
+ async def __call__(self, input):
+ await asyncio.sleep(0.001)
+ if input.evaluation:
+ return RolloutFnEvalOutput(data={"benchmark": {"accuracy": 0.98}})
+ return RolloutFnTrainOutput(samples=[[{"text": "async"}]])
+
+ with function_registry.temporary("test:async_class", AsyncRolloutFn):
+ fn = load_rollout_function(constructor_input, "test:async_class")
+
+ input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput
+ result = call_rollout_function(fn, input_cls(rollout_id=1))
+
+ assert isinstance(fn, AsyncRolloutFn)
+ expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput
+ assert isinstance(result, expected_type)
+
+
+class TestSupportedGenerateFormats:
+ """
+ Documentation tests, analogous to TestSupportedRolloutFormats, covering the supported generate function formats.
+ """
+
+ @pytest.mark.parametrize("evaluation", [False, True])
+ def test_format_1_legacy_function_with_evaluation_param(self, make_generate_fn_input, evaluation):
+ async def legacy_generate_fn(args, sample, sampling_params, evaluation=False):
+ return "my_sample"
+
+ with function_registry.temporary("test:legacy_gen_eval", legacy_generate_fn):
+ fn = load_generate_function("test:legacy_gen_eval")
+
+ result = run(fn(make_generate_fn_input(evaluation)))
+
+ assert isinstance(fn, LegacyGenerateFnAdapter)
+ assert isinstance(result, GenerateFnOutput)
+ assert result.samples == "my_sample"
+
+ @pytest.mark.parametrize("evaluation", [False, True])
+ def test_format_2_legacy_function_without_evaluation_param(self, make_generate_fn_input, evaluation):
+ async def legacy_generate_fn(args, sample, sampling_params):
+ return "my_sample"
+
+ with function_registry.temporary("test:legacy_gen", legacy_generate_fn):
+ fn = load_generate_function("test:legacy_gen")
+
+ result = run(fn(make_generate_fn_input(evaluation)))
+
+ assert isinstance(fn, LegacyGenerateFnAdapter)
+ assert isinstance(result, GenerateFnOutput)
+ assert result.samples == "my_sample"
+
+ @pytest.mark.parametrize("evaluation", [False, True])
+ def test_format_3_new_async_function_api(self, make_generate_fn_input, evaluation):
+ async def generate(input: GenerateFnInput) -> GenerateFnOutput:
+ return GenerateFnOutput(samples="my_sample")
+
+ with function_registry.temporary("test:new_async", generate):
+ fn = load_generate_function("test:new_async")
+
+ result = run(fn(make_generate_fn_input(evaluation)))
+
+ assert isinstance(result, GenerateFnOutput)
+ assert result.samples == "my_sample"
+
+ @pytest.mark.parametrize("evaluation", [False, True])
+ def test_format_4_new_class_api(self, make_generate_fn_input, evaluation):
+ class MyGenerateFn:
+ async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput:
+ return GenerateFnOutput(samples="my_sample")
+
+ with function_registry.temporary("test:new_class", MyGenerateFn):
+ fn = load_generate_function("test:new_class")
+
+ result = run(fn(make_generate_fn_input(evaluation)))
+
+ assert isinstance(fn, MyGenerateFn)
+ assert isinstance(result, GenerateFnOutput)
+ assert result.samples == "my_sample"
diff --git a/tests/fast/rollout/rm_hub/__init__.py b/tests/fast/rollout/rm_hub/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/rollout/rm_hub/test_deepscaler.py b/tests/fast/rollout/rm_hub/test_deepscaler.py
new file mode 100644
index 000000000..bd4c606a6
--- /dev/null
+++ b/tests/fast/rollout/rm_hub/test_deepscaler.py
@@ -0,0 +1,26 @@
+import pytest
+
+from miles.rollout.rm_hub.deepscaler import get_deepscaler_rule_based_reward
+
+
+class TestGetDeepscalerRuleBasedReward:
+ @pytest.mark.parametrize(
+ "response,label,expected",
+ [
+ (r"Let me analyze...The answer is \boxed{42}", "42", 1),
+ (r"Thinking...The answer is \boxed{wrong}", "42", 0),
+ (r"###Response\boxed{42}", "42", 1),
+ (r"###Response\boxed{wrong}", "42", 0),
+ (r"The answer is \boxed{42}", "42", 0),
+ (r"The answer is 42", "42", 0),
+ (r"\boxed{42}", "", 0),
+ (r"\boxed{42}", r"\boxed{42}", 1),
+ (r"\boxed{123}", 123, 1),
+ (r"\boxed{3.14}", 3.14, 1),
+ (r"\boxed{1/2}", "0.5", 1),
+ (r"\boxed{\frac{1}{2}}", "0.5", 1),
+ (r"First thoughtSecond thought\boxed{42}", "42", 1),
+ ],
+ )
+ def test_get_deepscaler_rule_based_reward(self, response, label, expected):
+ assert get_deepscaler_rule_based_reward(response, label) == expected
diff --git a/tests/fast/rollout/rm_hub/test_f1.py b/tests/fast/rollout/rm_hub/test_f1.py
new file mode 100644
index 000000000..c9ecf9614
--- /dev/null
+++ b/tests/fast/rollout/rm_hub/test_f1.py
@@ -0,0 +1,44 @@
+import pytest
+
+from miles.rollout.rm_hub.f1 import f1_score, normalize_answer
+
+
+class TestNormalizeAnswer:
+ @pytest.mark.parametrize(
+ "input_str,expected",
+ [
+ ("Hello World", "hello world"),
+ ("The quick brown fox", "quick brown fox"),
+ ("A cat and a dog", "cat and dog"),
+ ("Hello, world!", "hello world"),
+ (" multiple spaces ", "multiple spaces"),
+ ("An apple", "apple"),
+ ("UPPERCASE", "uppercase"),
+ ],
+ )
+ def test_normalize_answer(self, input_str, expected):
+ assert normalize_answer(input_str) == expected
+
+
+class TestF1Score:
+ @pytest.mark.parametrize(
+ "prediction,ground_truth,expected_f1,expected_prec,expected_recall",
+ [
+ ("hello world", "hello world", 1.0, 1.0, 1.0),
+ ("hello world foo", "hello world bar", 2 / 3, 2 / 3, 2 / 3),
+ ("abc", "xyz", 0, 0, 0),
+ (None, "anything", 0, 0, 0),
+ ("yes", "no", 0, 0, 0),
+ ("no", "yes", 0, 0, 0),
+ ("yes", "yes", 1.0, 1.0, 1.0),
+ ("noanswer", "yes", 0, 0, 0),
+ ("the answer is correct", "answer is correct", 1.0, 1.0, 1.0),
+ ("hello, world!", "hello world", 1.0, 1.0, 1.0),
+ ("hello", "hello world", pytest.approx(2 / 3), 1.0, 0.5),
+ ],
+ )
+ def test_f1_score(self, prediction, ground_truth, expected_f1, expected_prec, expected_recall):
+ f1, prec, recall = f1_score(prediction, ground_truth)
+ assert f1 == expected_f1
+ assert prec == expected_prec
+ assert recall == expected_recall
diff --git a/tests/fast/rollout/rm_hub/test_gpqa.py b/tests/fast/rollout/rm_hub/test_gpqa.py
new file mode 100644
index 000000000..45cefd201
--- /dev/null
+++ b/tests/fast/rollout/rm_hub/test_gpqa.py
@@ -0,0 +1,86 @@
+import pytest
+
+from miles.rollout.rm_hub.gpqa import (
+ _extract_letter_from_response,
+ _normalize_text,
+ _strip_chain_of_thought,
+ compute_gpqa_reward,
+)
+
+
+class TestStripChainOfThought:
+ @pytest.mark.parametrize(
+ "text,expected",
+ [
+ ("Let me think...The answer is A", "The answer is A"),
+ ("The answer is A", "The answer is A"),
+ ("", ""),
+ (None, ""),
+ ],
+ )
+ def test_strip_chain_of_thought(self, text, expected):
+ assert _strip_chain_of_thought(text) == expected
+
+
+class TestNormalizeText:
+ @pytest.mark.parametrize(
+ "input_str,expected",
+ [
+ ("Hello World", "hello world"),
+ ("Test-123", "test 123"),
+ ("A, B, C", "a b c"),
+ ("", ""),
+ ],
+ )
+ def test_normalize_text(self, input_str, expected):
+ assert _normalize_text(input_str) == expected
+
+
+class TestExtractLetterFromResponse:
+ @pytest.mark.parametrize(
+ "response,expected",
+ [
+ ("The answer is A", "A"),
+ ("answer: B", "B"),
+ ("I think C is correct", "C"),
+ ("final answer: D", "D"),
+ ("Option A is the best choice", "A"),
+ ("The answer is B", "B"),
+ ("After analysis, my choice is C", "C"),
+ ("A B C D", "D"),
+ ("No valid letter here", None),
+ ("", None),
+ (None, None),
+ ("The answer is Z", None),
+ ],
+ )
+ def test_extract_letter(self, response, expected):
+ assert _extract_letter_from_response(response, "ABCD") == expected
+
+
+class TestComputeGpqaReward:
+ @pytest.mark.parametrize(
+ "response,label,metadata,expected",
+ [
+ ("Answer: A", "A", None, 1.0),
+ ("Answer: A", "B", None, 0.0),
+ (None, "A", None, 0.0),
+ ("Answer: B", "ignored", {"correct_letter": "B"}, 1.0),
+ ("Answer: A", "ignored", {"correct_letter": "B"}, 0.0),
+ ("Answer: A", 0, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0),
+ ("Answer: B", 1, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0),
+ ("Answer: X", "X", {"valid_letters": ["X", "Y", "Z"]}, 1.0),
+ ("Answer: A", "X", {"valid_letters": ["X", "Y", "Z"]}, 0.0),
+ (
+ "I believe the answer is Paris",
+ "",
+ {"choices": ["Paris", "London", "Berlin", "Rome"], "correct_letter": "A"},
+ 1.0,
+ ),
+ ("Answer: A", "", {"choices": {"A": "Paris", "B": "London"}, "correct_letter": "A"}, 1.0),
+ ("The answer is Paris", "Paris", {"choices": ["Paris", "London", "Berlin", "Rome"]}, 1.0),
+ ("Let me think step by step...The answer is A", "A", None, 1.0),
+ ],
+ )
+ def test_compute_gpqa_reward(self, response, label, metadata, expected):
+ assert compute_gpqa_reward(response, label, metadata=metadata) == expected
diff --git a/tests/fast/rollout/rm_hub/test_math_dapo_utils.py b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py
new file mode 100644
index 000000000..56a7f6d1f
--- /dev/null
+++ b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py
@@ -0,0 +1,108 @@
+import pytest
+
+from miles.rollout.rm_hub.math_dapo_utils import (
+ compute_score,
+ is_correct_minerva,
+ is_correct_strict_box,
+ last_boxed_only_string,
+ normalize_final_answer,
+ remove_boxed,
+)
+
+
+class TestLastBoxedOnlyString:
+ @pytest.mark.parametrize(
+ "input_str,expected",
+ [
+ (r"The answer is \boxed{42}", r"\boxed{42}"),
+ (r"\boxed{x^2}", r"\boxed{x^2}"),
+ (r"No boxed", None),
+ (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"),
+ ],
+ )
+ def test_last_boxed_only_string(self, input_str, expected):
+ assert last_boxed_only_string(input_str) == expected
+
+
+class TestRemoveBoxed:
+ @pytest.mark.parametrize(
+ "input_str,expected",
+ [
+ (r"\boxed{42}", "42"),
+ (r"\boxed{x + 1}", "x + 1"),
+ ],
+ )
+ def test_remove_boxed_valid(self, input_str, expected):
+ assert remove_boxed(input_str) == expected
+
+ def test_remove_boxed_invalid(self):
+ with pytest.raises(AssertionError):
+ remove_boxed("not boxed")
+
+
+class TestNormalizeFinalAnswer:
+ @pytest.mark.parametrize(
+ "input_str,expected",
+ [
+ ("42", "42"),
+ (" 42 ", "42"),
+ (r"\text{hello}", "hello"),
+ (r"\textbf{bold}", "bold"),
+ (r"x = 42", "42"),
+ (r"100 square", "100"),
+ (r"$50$ dollars", "50"),
+ (r"\boxed{42}", "42"),
+ (r"\frac12", r"\frac{1}{2}"),
+ (r"\sqrt3", r"\sqrt{3}"),
+ ("1,000", "1000"),
+ ("<|im_end|>", ""),
+ ],
+ )
+ def test_normalize_final_answer(self, input_str, expected):
+ assert normalize_final_answer(input_str) == expected
+
+
+class TestIsCorrectMinerva:
+ @pytest.mark.parametrize(
+ "solution,gt,gt_need_extract,expected_correct",
+ [
+ ("Answer: 42", "42", False, True),
+ ("Answer: 100", "42", False, False),
+ ("Answer: wrong", "42", False, False),
+ ("Answer: 42", r"\boxed{42}", True, True),
+ ],
+ )
+ def test_is_correct_minerva(self, solution, gt, gt_need_extract, expected_correct):
+ correct, pred = is_correct_minerva(solution, gt, gt_need_extract=gt_need_extract)
+ assert correct == expected_correct
+
+
+class TestIsCorrectStrictBox:
+ @pytest.mark.parametrize(
+ "pred,gt,expected_score,expected_pred",
+ [
+ (r"blah blah \boxed{42}", "42", 1, "42"),
+ (r"\boxed{wrong}", "42", -1, "wrong"),
+ ("no box here", "42", -1, None),
+ ],
+ )
+ def test_is_correct_strict_box(self, pred, gt, expected_score, expected_pred):
+ score, extracted = is_correct_strict_box(pred, gt)
+ assert score == expected_score
+ assert extracted == expected_pred
+
+
+class TestComputeScore:
+ @pytest.mark.parametrize(
+ "solution,gt,strict_box,expected_score,expected_acc",
+ [
+ ("Answer: 42", "42", False, 1.0, True),
+ ("Answer: wrong", "42", False, -1.0, False),
+ (r"\boxed{42}", "42", True, 1.0, True),
+ ("x" * 500 + " Answer: 42", "42", False, 1.0, True),
+ ],
+ )
+ def test_compute_score(self, solution, gt, strict_box, expected_score, expected_acc):
+ result = compute_score(solution, gt, strict_box_verify=strict_box)
+ assert result["score"] == expected_score
+ assert result["acc"] == expected_acc
diff --git a/tests/fast/rollout/rm_hub/test_math_utils.py b/tests/fast/rollout/rm_hub/test_math_utils.py
new file mode 100644
index 000000000..2423ed4ac
--- /dev/null
+++ b/tests/fast/rollout/rm_hub/test_math_utils.py
@@ -0,0 +1,129 @@
+import pytest
+
+from miles.rollout.rm_hub.math_utils import (
+ _normalize,
+ extract_answer,
+ grade_answer_mathd,
+ grade_answer_sympy,
+ grade_answer_verl,
+ last_boxed_only_string,
+ remove_boxed,
+)
+
+
+class TestLastBoxedOnlyString:
+ @pytest.mark.parametrize(
+ "input_str,expected",
+ [
+ (r"The answer is \boxed{42}", r"\boxed{42}"),
+ (r"\boxed{x^2 + 1}", r"\boxed{x^2 + 1}"),
+ (r"So \boxed{\frac{1}{2}}", r"\boxed{\frac{1}{2}}"),
+ (r"No boxed here", None),
+ (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"),
+ (r"\boxed{nested {braces}}", r"\boxed{nested {braces}}"),
+ (r"\fbox{fbox content}", r"\fbox{fbox content}"),
+ ("", None),
+ ],
+ )
+ def test_last_boxed_only_string(self, input_str, expected):
+ assert last_boxed_only_string(input_str) == expected
+
+
+class TestRemoveBoxed:
+ @pytest.mark.parametrize(
+ "input_str,expected",
+ [
+ (r"\boxed{42}", "42"),
+ (r"\boxed{x^2 + 1}", "x^2 + 1"),
+ (r"\boxed{\frac{1}{2}}", r"\frac{1}{2}"),
+ ("not boxed", None),
+ ],
+ )
+ def test_remove_boxed(self, input_str, expected):
+ assert remove_boxed(input_str) == expected
+
+
+class TestExtractAnswer:
+ @pytest.mark.parametrize(
+ "input_str,expected",
+ [
+ (r"The answer is \boxed{42}", "42"),
+ (r"So \boxed{\frac{1}{2}}", r"\frac{1}{2}"),
+ (r"Multiple \boxed{1} then \boxed{final}", "final"),
+ (r"No boxed here", None),
+ ("", None),
+ ],
+ )
+ def test_extract_answer(self, input_str, expected):
+ assert extract_answer(input_str) == expected
+
+
+class TestNormalize:
+ @pytest.mark.parametrize(
+ "input_str,expected",
+ [
+ ("1,000", "1000"),
+ (r"\text{hello}", "hello"),
+ (" 42 ", "42"),
+ (r"100%", "100"),
+ (r"\$50", "50"),
+ ("HELLO", "hello"),
+ ("1,234,567", "1234567"),
+ (None, None),
+ ],
+ )
+ def test_normalize(self, input_str, expected):
+ assert _normalize(input_str) == expected
+
+
+class TestGradeAnswerMathd:
+ @pytest.mark.parametrize(
+ "given,ground_truth,expected",
+ [
+ ("42", "42", True),
+ (" 42 ", "42", True),
+ (r"\frac{1}{2}", r"\frac{1}{2}", True),
+ ("wrong", "42", False),
+ ("", "42", False),
+ ],
+ )
+ def test_grade_answer_mathd(self, given, ground_truth, expected):
+ assert grade_answer_mathd(given, ground_truth) == expected
+
+
+class TestGradeAnswerSympy:
+ @pytest.mark.parametrize(
+ "given,ground_truth,expected",
+ [
+ ("42", "42", True),
+ ("x^2", "x^2", True),
+ ("1/2", "0.5", True),
+ (r"\frac{1}{2}", "0.5", True),
+ ("wrong", "42", False),
+ ("", "42", False),
+ ("(1,2)", "(1,2)", True),
+ ("(1,2,3)", "(1,2)", False),
+ ("42", None, False),
+ ],
+ )
+ def test_grade_answer_sympy(self, given, ground_truth, expected):
+ assert grade_answer_sympy(given, ground_truth) == expected
+
+
+class TestGradeAnswerVerl:
+ @pytest.mark.parametrize(
+ "solution,ground_truth,expected",
+ [
+ (r"\boxed{42}", "42", True),
+ (r"The answer is \boxed{42}", "42", True),
+ (r"\boxed{1/2}", r"\frac{1}{2}", True),
+ (r"\boxed{wrong}", "42", False),
+ ("no boxed", "42", False),
+ (r"\boxed{42}", r"\boxed{42}", True),
+ ("", "42", False),
+ (r"\boxed{42}", "", False),
+ (r"\boxed{42}", None, False),
+ ],
+ )
+ def test_grade_answer_verl(self, solution, ground_truth, expected):
+ assert grade_answer_verl(solution, ground_truth) == expected
diff --git a/tests/fast/rollout/rm_hub/test_rm_hub.py b/tests/fast/rollout/rm_hub/test_rm_hub.py
new file mode 100644
index 000000000..a3dadbdaf
--- /dev/null
+++ b/tests/fast/rollout/rm_hub/test_rm_hub.py
@@ -0,0 +1,126 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from miles.rollout.rm_hub import async_rm, batched_async_rm
+from miles.utils.async_utils import run
+from miles.utils.types import Sample
+
+
+@pytest.fixture
+def mock_args():
+ args = MagicMock()
+ args.custom_rm_path = None
+ args.rm_type = None
+ args.rm_url = None
+ return args
+
+
+class TestAsyncRm:
+ @pytest.mark.parametrize(
+ "rm_type,response,label,expected",
+ [
+ ("math", r"\boxed{42}", "42", 1),
+ ("math", r"\boxed{wrong}", "42", 0),
+ ("f1", "hello world", "hello world", 1.0),
+ ("dapo", "Answer: 42", "42", {"score": 1.0}),
+ ("deepscaler", r"\boxed{42}", "42", 1),
+ ("gpqa", "Answer: A", "A", 1.0),
+ ("boxed_f1", r"Final answer is \boxed{hello world}", "hello world", 1.0),
+ ],
+ )
+ def test_rm_types(self, mock_args, rm_type, response, label, expected):
+ mock_args.rm_type = rm_type
+ sample = Sample(prompt="", response=response, label=label)
+ reward = run(async_rm(mock_args, sample))
+ if isinstance(expected, dict):
+ for k, v in expected.items():
+ assert reward[k] == v
+ else:
+ assert reward == expected
+
+ def test_f1_rm_partial(self, mock_args):
+ mock_args.rm_type = "f1"
+ sample = Sample(prompt="", response="hello", label="hello world")
+ reward = run(async_rm(mock_args, sample))
+ assert 0 < reward < 1
+
+ def test_random_rm(self, mock_args):
+ mock_args.rm_type = "random"
+ sample = Sample(prompt="", response="anything", label="anything")
+ reward = run(async_rm(mock_args, sample))
+ assert reward in [0, 1]
+
+ def test_rm_type_from_metadata(self, mock_args):
+ mock_args.rm_type = None
+ sample = Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"})
+ reward = run(async_rm(mock_args, sample))
+ assert reward == 1
+
+ @pytest.mark.parametrize(
+ "rm_type,match",
+ [
+ ("unknown_type", "not implemented"),
+ ("", "not specified"),
+ ],
+ )
+ def test_invalid_rm_type_raises(self, mock_args, rm_type, match):
+ mock_args.rm_type = rm_type
+ sample = Sample(prompt="", response="test", label="test")
+ with pytest.raises(NotImplementedError, match=match):
+ run(async_rm(mock_args, sample))
+
+
+class TestBatchedAsyncRm:
+ @pytest.mark.parametrize(
+ "rm_type,samples_data,expected",
+ [
+ (
+ "math",
+ [(r"\boxed{42}", "42"), (r"\boxed{100}", "100"), (r"\boxed{wrong}", "42")],
+ [1, 1, 0],
+ ),
+ (
+ "f1",
+ [("hello world", "hello world"), ("different", "something else")],
+ [1.0, 0],
+ ),
+ ],
+ )
+ def test_batched_rm(self, mock_args, rm_type, samples_data, expected):
+ mock_args.rm_type = rm_type
+ samples = [Sample(prompt="", response=r, label=label) for r, label in samples_data]
+ rewards = run(batched_async_rm(mock_args, samples))
+ assert rewards == expected
+
+ def test_inplace_set_reward_field(self, mock_args):
+ mock_args.rm_type = "math"
+ samples = [
+ Sample(prompt="", response=r"\boxed{42}", label="42"),
+ Sample(prompt="", response=r"\boxed{100}", label="100"),
+ ]
+ result = run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True))
+ assert result is None
+ assert samples[0].reward == 1
+ assert samples[1].reward == 1
+
+ def test_inplace_raises_on_existing_reward(self, mock_args):
+ mock_args.rm_type = "math"
+ samples = [Sample(prompt="", response=r"\boxed{42}", label="42", reward=0.5)]
+ with pytest.raises(AssertionError, match="Overriding"):
+ run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True))
+
+ def test_empty_samples(self, mock_args):
+ mock_args.rm_type = "math"
+ rewards = run(batched_async_rm(mock_args, []))
+ assert rewards == []
+
+ def test_mixed_rm_types_via_metadata(self, mock_args):
+ mock_args.rm_type = None
+ samples = [
+ Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}),
+ Sample(prompt="", response="hello", label="hello", metadata={"rm_type": "f1"}),
+ ]
+ rewards = run(batched_async_rm(mock_args, samples))
+ assert rewards[0] == 1
+ assert rewards[1] == 1.0
diff --git a/tests/fast/router/__init__.py b/tests/fast/router/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/router/test_router.py b/tests/fast/router/test_router.py
new file mode 100644
index 000000000..8bad5874c
--- /dev/null
+++ b/tests/fast/router/test_router.py
@@ -0,0 +1,207 @@
+import asyncio
+from argparse import Namespace
+
+import pytest
+import requests
+
+from miles.router.router import MilesRouter
+from miles.utils.http_utils import find_available_port
+from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, default_process_fn
+from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer
+
+
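+# Minimal Namespace standing in for the parsed CLI arguments MilesRouter reads;
+# keyword overrides let individual tests tweak single settings.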
+def make_router_args(router_port: int, **overrides) -> Namespace:
+ defaults = dict(
+ sglang_router_ip="127.0.0.1",
+ sglang_router_port=router_port,
+ rollout_health_check_interval=1.0,
+ miles_router_health_check_failure_threshold=3,
+ miles_router_max_connections=100,
+ miles_router_timeout=None,
+ miles_router_middleware_paths=[],
+ hf_checkpoint="Qwen/Qwen3-0.6B",
+ cross_turn_token_out=False,
+ inherit_last_assistant=False,
+ )
+ defaults.update(overrides)
+ return Namespace(**defaults)
+
+
+def create_mock_worker(start_port: int = 30000) -> MockSGLangServer:
+ port = find_available_port(start_port)
+ return MockSGLangServer(
+ model_name="Qwen/Qwen3-0.6B",
+ process_fn=default_process_fn,
+ host="127.0.0.1",
+ port=port,
+ latency=0.0,
+ )
+
+
+class RouterEnv:
+ def __init__(self, router: MilesRouter, server: UvicornThreadServer):
+ self.router = router
+ self.server = server
+
+ @property
+ def url(self) -> str:
+ return self.server.url
+
+
+@pytest.fixture
+def router_env():
+ args = make_router_args(find_available_port(20000))
+ router = MilesRouter(args, verbose=False)
+ server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port)
+ server.start()
+ yield RouterEnv(router, server)
+ server.stop()
+
+
+@pytest.fixture
+def mock_worker():
+ server = create_mock_worker()
+ server.start()
+ yield server
+ server.stop()
+
+
+@pytest.fixture
+def mock_worker_factory():
+ servers = []
+
+ def _create():
+ start_port = 30000 + len(servers) * 100
+ server = create_mock_worker(start_port)
+ server.start()
+ servers.append(server)
+ return server
+
+ yield _create
+ for s in servers:
+ s.stop()
+
+
+@pytest.fixture
+def router_factory():
+ def _create(**overrides) -> MilesRouter:
+ args = make_router_args(find_available_port(20000), **overrides)
+ return MilesRouter(args, verbose=False)
+
+ return _create
+
+
+class TestWorkerManagement:
+ def test_add_worker_via_query_param(self, router_env: RouterEnv):
+ worker_url = "http://127.0.0.1:30001"
+ r = requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0)
+ r.raise_for_status()
+
+ assert r.json()["status"] == "success"
+ assert worker_url in router_env.router.worker_request_counts
+ assert router_env.router.worker_request_counts[worker_url] == 0
+
+ def test_add_worker_via_body(self, router_env: RouterEnv):
+ worker_url = "http://127.0.0.1:30002"
+ r = requests.post(f"{router_env.url}/add_worker", json={"url": worker_url}, timeout=5.0)
+ r.raise_for_status()
+
+ assert r.json()["status"] == "success"
+ assert worker_url in router_env.router.worker_request_counts
+
+ def test_add_worker_duplicate(self, router_env: RouterEnv):
+ worker_url = "http://127.0.0.1:30003"
+ requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status()
+ requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status()
+
+ assert len(router_env.router.worker_request_counts) == 1
+ assert worker_url in router_env.router.worker_request_counts
+
+ def test_add_worker_missing_url(self, router_env: RouterEnv):
+ r = requests.post(f"{router_env.url}/add_worker", json={}, timeout=5.0)
+ assert r.status_code == 400
+ assert "error" in r.json()
+
+ def test_list_workers(self, router_env: RouterEnv):
+ worker_urls = ["http://127.0.0.1:30001", "http://127.0.0.1:30002"]
+ for url in worker_urls:
+ requests.post(f"{router_env.url}/add_worker", params={"url": url}, timeout=5.0)
+
+ r = requests.get(f"{router_env.url}/list_workers", timeout=5.0)
+ r.raise_for_status()
+ assert set(r.json()["urls"]) == set(worker_urls)
+
+
+class TestLoadBalancing:
+ def test_use_url_selects_min_load(self, router_factory):
+ router = router_factory()
+ router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8}
+
+ selected = router._use_url()
+ assert selected == "http://w2:8000"
+ assert router.worker_request_counts["http://w2:8000"] == 3
+
+ def test_use_url_excludes_dead_workers(self, router_factory):
+ router = router_factory()
+ router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 1, "http://w3:8000": 3}
+ router.dead_workers = {"http://w2:8000"}
+
+ selected = router._use_url()
+ assert selected == "http://w3:8000"
+ assert router.worker_request_counts["http://w3:8000"] == 4
+
+ def test_use_url_raises_when_all_dead(self, router_factory):
+ router = router_factory()
+ router.worker_request_counts = {"http://w1:8000": 0}
+ router.dead_workers = {"http://w1:8000"}
+
+ with pytest.raises(RuntimeError, match="No healthy workers"):
+ router._use_url()
+
+
+# TODO: extract the main body of `_health_check_loop` into a helper so it can be tested directly
+class TestHealthCheck:
+ def test_check_worker_health_success(self, router_factory, mock_worker: MockSGLangServer):
+ router = router_factory()
+ url, healthy = asyncio.run(router._check_worker_health(mock_worker.url))
+ assert url == mock_worker.url
+ assert healthy is True
+
+ def test_check_worker_health_failure(self, router_factory):
+ router = router_factory()
+ url, healthy = asyncio.run(router._check_worker_health("http://127.0.0.1:59999"))
+ assert url == "http://127.0.0.1:59999"
+ assert healthy is False
+
+
+class TestProxyIntegration:
+ def test_proxy_forwards_request(self, router_env: RouterEnv, mock_worker: MockSGLangServer):
+ requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0).raise_for_status()
+
+ payload = {"input_ids": [1, 2, 3], "return_logprob": True}
+ r = requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0)
+ r.raise_for_status()
+
+ assert "text" in r.json()
+ assert len(mock_worker.request_log) == 1
+ assert mock_worker.request_log[0] == payload
+
+ def test_proxy_multi_worker(self, router_env: RouterEnv, mock_worker_factory):
+ worker1, worker2 = mock_worker_factory(), mock_worker_factory()
+ requests.post(f"{router_env.url}/add_worker", params={"url": worker1.url}, timeout=5.0)
+ requests.post(f"{router_env.url}/add_worker", params={"url": worker2.url}, timeout=5.0)
+
+ payload = {"input_ids": [1, 2, 3], "return_logprob": True}
+ for _ in range(4):
+ requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0).raise_for_status()
+
+ all_requests = worker1.request_log + worker2.request_log
+ assert len(all_requests) == 4
+ assert all(req == payload for req in all_requests)
+
+ def test_proxy_health_endpoint(self, router_env: RouterEnv, mock_worker: MockSGLangServer):
+ requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0)
+
+ r = requests.get(f"{router_env.url}/health", timeout=5.0)
+ r.raise_for_status()
+ assert r.json()["status"] == "ok"
diff --git a/tests/fast/router/test_seq_trajectory.py b/tests/fast/router/test_seq_trajectory.py
new file mode 100644
index 000000000..0705996a8
--- /dev/null
+++ b/tests/fast/router/test_seq_trajectory.py
@@ -0,0 +1,366 @@
+from types import SimpleNamespace
+
+import pytest
+from transformers import AutoTokenizer
+
+from miles.rollout.generate_utils.tokenize_utils import tokenize_messages
+from miles.router.session import seq_trajectory
+from miles.utils.chat_message_utils import get_think_token_start
+
+MODEL_NAME = "Qwen/Qwen3-4B"
+TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+
+
+def _messages(items: list[tuple[str, str]]) -> list[dict[str, str]]:
+ return [{"role": role, "content": content} for role, content in items]
+
+
+def _token_info_from_ids(token_ids: list[int]) -> seq_trajectory.TokenInfo:
+ return seq_trajectory.TokenInfo(
+ tokens=TOKENIZER.convert_ids_to_tokens(token_ids),
+ token_ids=token_ids,
+ log_probs=[0.0] * len(token_ids),
+ loss_mask=[1] * len(token_ids),
+ )
+
+
+def _turn(messages: list[dict[str, str]], prompt_ids: list[int], response_ids: list[int]) -> seq_trajectory.Turn:
+ payload = {
+ "messages": messages,
+ "prompt_tokens": _token_info_from_ids(prompt_ids),
+ "response_tokens": _token_info_from_ids(response_ids),
+ }
+ if hasattr(seq_trajectory.Turn, "model_construct"):
+ return seq_trajectory.Turn.model_construct(**payload)
+ return seq_trajectory.Turn.construct(**payload)
+
+
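+# Build a Turn whose prompt is the chat-template encoding of all but the last message
+# (with the generation prompt appended) and whose response is the final assistant
+# message encoded on its own.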
+def _turn_from_messages(messages: list[dict[str, str]]) -> seq_trajectory.Turn:
+ prompt_token_ids = TOKENIZER.apply_chat_template(
+ messages[:-1],
+ tokenize=True,
+ add_generation_prompt=True,
+ )
+ response_token_ids = TOKENIZER.encode(messages[-1]["content"], add_special_tokens=False)
+ return _turn(messages, prompt_token_ids, response_token_ids)
+
+
+def _assert_prompt_token_info(token_info: seq_trajectory.TokenInfo, expected_token_ids: list[int]) -> None:
+ assert token_info.token_ids == expected_token_ids
+ assert token_info.tokens == TOKENIZER.convert_ids_to_tokens(expected_token_ids)
+ assert token_info.log_probs == [0.0] * len(expected_token_ids)
+ assert token_info.loss_mask == [0] * len(expected_token_ids)
+
+
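+# Construct a SeqTrajectoryManager from a bare namespace carrying just the two flags
+# exercised by these tests.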
+def _make_manager(*, cross_turn_token_out: bool, inherit_last_assistant: bool) -> seq_trajectory.SeqTrajectoryManager:
+ args = SimpleNamespace(
+ cross_turn_token_out=cross_turn_token_out,
+ inherit_last_assistant=inherit_last_assistant,
+ )
+ return seq_trajectory.SeqTrajectoryManager(args, TOKENIZER)
+
+
+def test_turn_match_prefix_messages_returns_remaining():
+ messages = _messages([("user", "hi"), ("assistant", "ok"), ("user", "next"), ("assistant", "done")])
+ turn = _turn(messages, [], [])
+
+ remaining = turn.match_prefix_messages_and_return_remaining(messages[:2])
+
+ assert remaining == messages[2:]
+
+
+def test_turn_match_prefix_messages_exact_match_returns_empty():
+ messages = _messages([("user", "hi"), ("assistant", "ok")])
+ turn = _turn(messages, [], [])
+
+ remaining = turn.match_prefix_messages_and_return_remaining(messages)
+
+ assert remaining == []
+
+
+def test_turn_match_prefix_messages_mismatch_returns_none():
+ messages = _messages([("user", "hi"), ("assistant", "ok")])
+ turn = _turn(messages, [], [])
+
+ assert turn.match_prefix_messages_and_return_remaining([{"role": "user", "content": "nope"}]) is None
+ assert (
+ turn.match_prefix_messages_and_return_remaining(messages + [{"role": "assistant", "content": "extra"}]) is None
+ )
+
+
+def test_calc_prompt_tokens_info_multi_turn_cross_turn_disabled_uses_last_turn():
+ trajectory = seq_trajectory.SeqTrajectory()
+ turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ turn2_messages = _messages(
+ [
+ ("system", "sys"),
+ ("user", "u1"),
+ ("assistant", "a1"),
+ ("user", "u2"),
+ ("assistant", "a2"),
+ ]
+ )
+
+ trajectory.insert_new_turn(_turn_from_messages(turn1_messages))
+ trajectory.insert_new_turn(_turn_from_messages(turn2_messages))
+
+ token_info = trajectory.calc_prompt_tokens_info(
+ turn2_messages,
+ TOKENIZER,
+ cross_turn_token_out=False,
+ inherit_last_assistant=True,
+ )
+ expected_token_ids = TOKENIZER.apply_chat_template(turn2_messages, tokenize=True, add_generation_prompt=True)
+ _assert_prompt_token_info(token_info, expected_token_ids)
+
+
+def test_calc_prompt_tokens_info_multi_turn_cross_turn_uses_prefix_suffix():
+ trajectory = seq_trajectory.SeqTrajectory()
+ turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ turn2_messages = _messages(
+ [
+ ("system", "sys"),
+ ("user", "u1"),
+ ("assistant", "a1"),
+ ("user", "u2"),
+ ("assistant", "a2"),
+ ]
+ )
+ turn1 = _turn_from_messages(turn1_messages)
+ trajectory.insert_new_turn(turn1)
+ trajectory.insert_new_turn(_turn_from_messages(turn2_messages))
+
+ input_messages = _messages([("system", "sys")])
+ remain_messages = _messages([("user", "u1"), ("assistant", "a1")])
+
+ token_info = trajectory.calc_prompt_tokens_info(
+ input_messages,
+ TOKENIZER,
+ cross_turn_token_out=True,
+ inherit_last_assistant=False,
+ )
+ expected_new_token_ids = tokenize_messages(remain_messages, TOKENIZER, add_generation_prompt=True)
+ expected_token_ids = turn1.prompt_tokens.token_ids + turn1.response_tokens.token_ids + expected_new_token_ids
+ _assert_prompt_token_info(token_info, expected_token_ids)
+
+
+def test_calc_prompt_tokens_info_multi_turn_cross_turn_matches_two_turns():
+ trajectory = seq_trajectory.SeqTrajectory()
+ turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ turn2_messages = _messages([("user", "u1"), ("assistant", "a1"), ("user", "u2"), ("assistant", "a2")])
+ turn3_messages = _messages([("user", "u3"), ("assistant", "a3")])
+ turn2 = _turn_from_messages(turn2_messages)
+
+ trajectory.insert_new_turn(_turn_from_messages(turn1_messages))
+ trajectory.insert_new_turn(turn2)
+ trajectory.insert_new_turn(_turn_from_messages(turn3_messages))
+
+ input_messages = _messages([("system", "sys")])
+ remain_messages = _messages([("user", "u2"), ("assistant", "a2")])
+
+ token_info = trajectory.calc_prompt_tokens_info(
+ input_messages,
+ TOKENIZER,
+ cross_turn_token_out=True,
+ inherit_last_assistant=False,
+ )
+ expected_new_token_ids = tokenize_messages(remain_messages, TOKENIZER, add_generation_prompt=True)
+ expected_token_ids = turn2.prompt_tokens.token_ids + turn2.response_tokens.token_ids + expected_new_token_ids
+ _assert_prompt_token_info(token_info, expected_token_ids)
+
+
+def test_calc_prompt_tokens_info_multi_turn_cross_turn_empty_remaining_messages():
+ trajectory = seq_trajectory.SeqTrajectory()
+ turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ turn2_messages = _messages(
+ [
+ ("system", "sys"),
+ ("user", "u1"),
+ ("assistant", "a1"),
+ ("user", "u2"),
+ ("assistant", "a2"),
+ ]
+ )
+ turn1 = _turn_from_messages(turn1_messages)
+
+ trajectory.insert_new_turn(turn1)
+ trajectory.insert_new_turn(_turn_from_messages(turn2_messages))
+
+ token_info = trajectory.calc_prompt_tokens_info(
+ turn1_messages,
+ TOKENIZER,
+ cross_turn_token_out=True,
+ inherit_last_assistant=False,
+ )
+ expected_token_ids = turn1.prompt_tokens.token_ids + turn1.response_tokens.token_ids
+ _assert_prompt_token_info(token_info, expected_token_ids)
+
+
+def test_tokenize_messages_trims_complete_think_content():
+    messages_with_think = _messages([("assistant", "<think>thought</think>answer")])
+ messages_plain = _messages([("assistant", "answer")])
+
+ tokens_with_think = tokenize_messages(messages_with_think, TOKENIZER, add_generation_prompt=True)
+ tokens_plain = tokenize_messages(messages_plain, TOKENIZER, add_generation_prompt=True)
+
+ think_start_id = get_think_token_start("qwen3")[1]
+
+ assert tokens_with_think == tokens_plain
+ assert think_start_id not in tokens_with_think
+
+
+def test_tokenize_messages_does_not_trim_incomplete_think_content():
+    messages_incomplete_think = _messages([("assistant", "<think>thought answer")])
+ messages_plain = _messages([("assistant", "answer")])
+
+ tokens_incomplete = tokenize_messages(messages_incomplete_think, TOKENIZER, add_generation_prompt=True)
+ tokens_plain = tokenize_messages(messages_plain, TOKENIZER, add_generation_prompt=True)
+
+ think_start_id = get_think_token_start("qwen3")[1]
+
+ assert tokens_incomplete != tokens_plain
+ assert think_start_id in tokens_incomplete
+
+
+def test_manager_calc_prompt_tokens_missing_session_returns_none():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ messages = _messages([("system", "sys"), ("user", "hi")])
+
+ assert manager.calc_prompt_tokens("missing", messages) is None
+
+
+def test_manager_get_session_by_id_empty_returns_empty_token_info():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ session_id = manager.create_session()
+
+ token_info = manager.get_session_by_id(session_id)
+ assert token_info is not None
+ assert token_info.tokens == []
+ assert token_info.token_ids == []
+ assert token_info.log_probs == []
+ assert token_info.loss_mask == []
+
+
+def test_manager_calc_prompt_tokens_no_turns_retokens_messages():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ session_id = manager.create_session()
+ messages = _messages([("system", "sys"), ("user", "u1")])
+
+ token_info = manager.calc_prompt_tokens(session_id, messages)
+
+ expected_token_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
+ _assert_prompt_token_info(token_info, expected_token_ids)
+
+
+def test_manager_calc_prompt_tokens_inherit_last_assistant_raises():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=True)
+ session_id = manager.create_session()
+ turn_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ manager.add_record(session_id, _turn_from_messages(turn_messages))
+
+ with pytest.raises(NotImplementedError):
+ manager.calc_prompt_tokens(session_id, turn_messages)
+
+
+def test_manager_calc_prompt_tokens_cross_turn_single_turn_uses_tokenize_messages():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ session_id = manager.create_session()
+ turn_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ manager.add_record(session_id, _turn_from_messages(turn_messages))
+
+ messages = _messages([("system", "sys"), ("user", "next")])
+ token_info = manager.calc_prompt_tokens(session_id, messages)
+
+ expected_token_ids = tokenize_messages(messages, TOKENIZER, add_generation_prompt=True)
+ _assert_prompt_token_info(token_info, expected_token_ids)
+
+
+def test_manager_calc_prompt_tokens_cross_turn_multi_turn_prefix_success():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ session_id = manager.create_session()
+ turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ turn2_messages = _messages([("user", "u2"), ("assistant", "a2")])
+ turn1 = _turn_from_messages(turn1_messages)
+ manager.add_record(session_id, turn1)
+ manager.add_record(session_id, _turn_from_messages(turn2_messages))
+
+ input_messages = _messages([("system", "sys")])
+ token_info = manager.calc_prompt_tokens(session_id, input_messages)
+
+ remain_messages = _messages([("user", "u1"), ("assistant", "a1")])
+ expected_new_token_ids = tokenize_messages(remain_messages, TOKENIZER, add_generation_prompt=True)
+ expected_token_ids = turn1.prompt_tokens.token_ids + turn1.response_tokens.token_ids + expected_new_token_ids
+ _assert_prompt_token_info(token_info, expected_token_ids)
+
+
+def test_manager_calc_prompt_tokens_cross_turn_multi_turn_prefix_mismatch_raises():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ session_id = manager.create_session()
+ turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ manager.add_record(session_id, _turn_from_messages(turn1_messages))
+ manager.add_record(session_id, _turn_from_messages(_messages([("user", "u2"), ("assistant", "a2")])))
+
+ with pytest.raises(ValueError):
+ manager.calc_prompt_tokens(session_id, _messages([("system", "nope")]))
+
+
+def test_manager_calc_prompt_tokens_cross_turn_disabled_retokens_messages():
+ manager = _make_manager(cross_turn_token_out=False, inherit_last_assistant=True)
+ session_id = manager.create_session()
+ manager.add_record(
+ session_id, _turn_from_messages(_messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]))
+ )
+
+ messages = _messages([("system", "sys"), ("user", "new")])
+ token_info = manager.calc_prompt_tokens(session_id, messages)
+
+ expected_token_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
+ _assert_prompt_token_info(token_info, expected_token_ids)
+
+
+def test_manager_get_session_by_id_after_add_record_returns_combined_tokens():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ session_id = manager.create_session()
+ messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ turn = _turn_from_messages(messages)
+ manager.add_record(session_id, turn)
+
+ token_info = manager.get_session_by_id(session_id)
+
+ expected_token_ids = turn.prompt_tokens.token_ids + turn.response_tokens.token_ids
+ assert token_info.token_ids == expected_token_ids
+ assert token_info.tokens == TOKENIZER.convert_ids_to_tokens(expected_token_ids)
+ assert token_info.log_probs == turn.prompt_tokens.log_probs + turn.response_tokens.log_probs
+ assert token_info.loss_mask == turn.prompt_tokens.loss_mask + turn.response_tokens.loss_mask
+
+
+def test_manager_delete_session_by_id():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ session_id = manager.create_session()
+
+ assert manager.delete_session_by_id(session_id) is True
+ assert manager.delete_session_by_id(session_id) is False
+
+
+def test_manager_add_record_missing_session_raises():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ turn = _turn_from_messages(messages)
+
+ with pytest.raises(ValueError):
+ manager.add_record("missing", turn)
+
+
+def test_manager_calc_prompt_tokens_cross_turn_multi_turn_empty_remaining_messages():
+ manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False)
+ session_id = manager.create_session()
+ turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])
+ turn2_messages = _messages([("user", "u2"), ("assistant", "a2")])
+ turn1 = _turn_from_messages(turn1_messages)
+ manager.add_record(session_id, turn1)
+ manager.add_record(session_id, _turn_from_messages(turn2_messages))
+
+ token_info = manager.calc_prompt_tokens(session_id, turn1_messages)
+
+ expected_token_ids = turn1.prompt_tokens.token_ids + turn1.response_tokens.token_ids
+ _assert_prompt_token_info(token_info, expected_token_ids)
diff --git a/tests/fast/router/test_sessions.py b/tests/fast/router/test_sessions.py
new file mode 100644
index 000000000..3ab179fde
--- /dev/null
+++ b/tests/fast/router/test_sessions.py
@@ -0,0 +1,314 @@
+from types import SimpleNamespace
+
+import pytest
+import requests
+from transformers import AutoTokenizer
+
+from miles.rollout.generate_utils.tokenize_utils import tokenize_messages
+from miles.router.router import MilesRouter
+from miles.utils.http_utils import find_available_port
+from miles.utils.test_utils.mock_sglang_server import ProcessResult, with_mock_server
+from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+
+
+@pytest.fixture(scope="module")
+def router_env():
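+    """Module-scoped fixture: a MilesRouter in front of a mock SGLang backend, served on a local port."""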
+ def process_fn(_prompt: str) -> ProcessResult:
+ return ProcessResult(text="ok", finish_reason="stop")
+
+ with with_mock_server(model_name=MODEL_NAME, process_fn=process_fn) as backend:
+ args = SimpleNamespace(
+ miles_router_max_connections=10,
+ miles_router_timeout=30,
+ miles_router_middleware_paths=[],
+ rollout_health_check_interval=60,
+ miles_router_health_check_failure_threshold=3,
+ hf_checkpoint=MODEL_NAME,
+ cross_turn_token_out=False,
+ inherit_last_assistant=False,
+ )
+ router = MilesRouter(args)
+
+ port = find_available_port(31000)
+ server = UvicornThreadServer(router.app, host="127.0.0.1", port=port)
+ server.start()
+
+ url = f"http://127.0.0.1:{port}"
+ requests.post(f"{url}/add_worker", json={"url": backend.url})
+
+ try:
+ yield {"url": url, "backend": backend}
+ finally:
+ server.stop()
+
+
+def _create_session(url: str) -> str:
+ response = requests.post(f"{url}/sessions")
+ assert response.status_code == 200
+ return response.json()["session_id"]
+
+
+def _extract_response_tokens(response_body: dict) -> tuple[list[int], list[float], list[str]]:
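+    """Pull (token_ids, log_probs, tokens) from the chat-completion logprobs content, preferring an explicit token_id field."""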
+ logprobs_content = response_body["choices"][0]["logprobs"]["content"]
+ token_ids = [item.get("token_id", TOKENIZER.convert_tokens_to_ids(item["token"])) for item in logprobs_content]
+ logprobs = [item["logprob"] for item in logprobs_content]
+ tokens = [item["token"] for item in logprobs_content]
+ return token_ids, logprobs, tokens
+
+
+def test_create_session_and_get_empty_records(router_env):
+ url = router_env["url"]
+ session_id = _create_session(url)
+
+ response = requests.get(f"{url}/sessions/{session_id}")
+ assert response.status_code == 200
+
+ data = response.json()
+ assert data["session_id"] == session_id
+ assert data["records"] == {
+ "tokens": [],
+ "token_ids": [],
+ "log_probs": [],
+ "loss_mask": [],
+ }
+
+
+def test_get_session_not_found(router_env):
+ url = router_env["url"]
+ response = requests.get(f"{url}/sessions/nonexistent")
+ assert response.status_code == 404
+ assert response.json()["error"] == "session not found"
+
+
+def test_delete_session(router_env):
+ url = router_env["url"]
+ session_id = _create_session(url)
+
+ delete_resp = requests.delete(f"{url}/sessions/{session_id}")
+ assert delete_resp.status_code == 204
+ assert delete_resp.text == ""
+
+ missing_resp = requests.delete(f"{url}/sessions/{session_id}")
+ assert missing_resp.status_code == 404
+ assert missing_resp.json()["error"] == "session not found"
+
+
+def test_proxy_session_not_found(router_env):
+ url = router_env["url"]
+ response = requests.post(
+ f"{url}/sessions/nonexistent/v1/chat/completions",
+ json={"messages": [{"role": "user", "content": "hi"}]},
+ )
+ assert response.status_code == 404
+ assert response.json()["error"] == "session not found"
+
+
+def test_proxy_inserts_input_ids_and_records_tokens(router_env):
+ url = router_env["url"]
+ backend = router_env["backend"]
+ backend.reset_stats()
+
+ session_id = _create_session(url)
+ messages = [
+ {"role": "system", "content": "sys"},
+ {"role": "user", "content": "hi"},
+ ]
+
+ response = requests.post(
+ f"{url}/sessions/{session_id}/v1/chat/completions",
+ json={"messages": messages},
+ )
+ assert response.status_code == 200
+
+ response_body = response.json()
+
+ expected_prompt_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
+ backend_payload = backend.request_log[-1]
+ assert backend_payload["input_ids"] == expected_prompt_ids
+
+ response_token_ids, response_logprobs, response_tokens = _extract_response_tokens(response_body)
+
+ get_resp = requests.get(f"{url}/sessions/{session_id}")
+ assert get_resp.status_code == 200
+
+ records = get_resp.json()["records"]
+ expected_token_ids = expected_prompt_ids + response_token_ids
+ assert records["token_ids"] == expected_token_ids
+ assert records["tokens"] == TOKENIZER.convert_ids_to_tokens(expected_prompt_ids) + response_tokens
+ assert records["log_probs"] == [0.0] * len(expected_prompt_ids) + response_logprobs
+ assert records["loss_mask"] == [0] * len(expected_prompt_ids) + [1] * len(response_token_ids)
+
+
+def test_proxy_preserves_input_ids_when_provided(router_env):
+ url = router_env["url"]
+ backend = router_env["backend"]
+ backend.reset_stats()
+
+ session_id = _create_session(url)
+ messages = [
+ {"role": "system", "content": "sys"},
+ {"role": "user", "content": "hi"},
+ ]
+ base_prompt_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
+ custom_input_ids = base_prompt_ids + [base_prompt_ids[-1]]
+
+ response = requests.post(
+ f"{url}/sessions/{session_id}/v1/chat/completions",
+ json={"messages": messages, "input_ids": custom_input_ids},
+ )
+ assert response.status_code == 200
+
+ backend_payload = backend.request_log[-1]
+ assert backend_payload["input_ids"] == custom_input_ids
+
+ response_body = response.json()
+ response_token_ids, response_logprobs, response_tokens = _extract_response_tokens(response_body)
+
+ get_resp = requests.get(f"{url}/sessions/{session_id}")
+ records = get_resp.json()["records"]
+ assert records["token_ids"] == response_token_ids
+ assert records["tokens"] == response_tokens
+ assert records["log_probs"] == response_logprobs
+ assert records["loss_mask"] == [1] * len(response_token_ids)
+
+
+def test_proxy_multi_turn_second_call_uses_only_new_messages(router_env):
+ url = router_env["url"]
+ backend = router_env["backend"]
+ backend.reset_stats()
+
+ session_id = _create_session(url)
+ messages_turn1 = [
+ {"role": "system", "content": "sys"},
+ {"role": "user", "content": "hi"},
+ ]
+ response1 = requests.post(
+ f"{url}/sessions/{session_id}/v1/chat/completions",
+ json={"messages": messages_turn1},
+ )
+ assert response1.status_code == 200
+
+ messages_turn2 = [{"role": "user", "content": "next"}]
+ response2 = requests.post(
+ f"{url}/sessions/{session_id}/v1/chat/completions",
+ json={"messages": messages_turn2},
+ )
+ assert response2.status_code == 200
+
+ expected_prompt_ids = tokenize_messages(messages_turn2, TOKENIZER, add_generation_prompt=True)
+ backend_payload = backend.request_log[-1]
+ assert backend_payload["input_ids"] == expected_prompt_ids
+
+ response2_body = response2.json()
+ response2_token_ids, response2_logprobs, response2_tokens = _extract_response_tokens(response2_body)
+
+ get_resp = requests.get(f"{url}/sessions/{session_id}")
+ records = get_resp.json()["records"]
+ expected_token_ids = expected_prompt_ids + response2_token_ids
+ assert records["token_ids"] == expected_token_ids
+ assert records["tokens"] == TOKENIZER.convert_ids_to_tokens(expected_prompt_ids) + response2_tokens
+ assert records["log_probs"] == [0.0] * len(expected_prompt_ids) + response2_logprobs
+ assert records["loss_mask"] == [0] * len(expected_prompt_ids) + [1] * len(response2_token_ids)
+
+
+def test_proxy_third_call_reuses_first_turn_prefix(router_env):
+ url = router_env["url"]
+ backend = router_env["backend"]
+ backend.reset_stats()
+
+ session_id = _create_session(url)
+ messages_turn1 = [
+ {"role": "system", "content": "sys"},
+ {"role": "user", "content": "hi"},
+ ]
+ response1 = requests.post(
+ f"{url}/sessions/{session_id}/v1/chat/completions",
+ json={"messages": messages_turn1},
+ )
+ assert response1.status_code == 200
+
+ response1_body = response1.json()
+ response1_token_ids, _, _ = _extract_response_tokens(response1_body)
+ prompt1_ids = TOKENIZER.apply_chat_template(messages_turn1, tokenize=True, add_generation_prompt=True)
+
+ response2 = requests.post(
+ f"{url}/sessions/{session_id}/v1/chat/completions",
+ json={"messages": [{"role": "user", "content": "next"}]},
+ )
+ assert response2.status_code == 200
+
+ assistant_message = response1_body["choices"][0]["message"]
+ messages_turn3 = [{"role": "system", "content": "sys"}]
+ response3 = requests.post(
+ f"{url}/sessions/{session_id}/v1/chat/completions",
+ json={"messages": messages_turn3},
+ )
+ assert response3.status_code == 200
+
+ remain_messages = [messages_turn1[1], assistant_message]
+ expected_prompt_ids = (
+ prompt1_ids
+ + response1_token_ids
+ + tokenize_messages(
+ remain_messages,
+ TOKENIZER,
+ add_generation_prompt=True,
+ )
+ )
+ backend_payload = backend.request_log[-1]
+ assert backend_payload["input_ids"] == expected_prompt_ids
+
+ response3_body = response3.json()
+ response3_token_ids, response3_logprobs, response3_tokens = _extract_response_tokens(response3_body)
+
+ get_resp = requests.get(f"{url}/sessions/{session_id}")
+ records = get_resp.json()["records"]
+ expected_token_ids = expected_prompt_ids + response3_token_ids
+ assert records["token_ids"] == expected_token_ids
+ assert records["tokens"] == TOKENIZER.convert_ids_to_tokens(expected_prompt_ids) + response3_tokens
+ assert records["log_probs"] == [0.0] * len(expected_prompt_ids) + response3_logprobs
+ assert records["loss_mask"] == [0] * len(expected_prompt_ids) + [1] * len(response3_token_ids)
+
+
+def test_proxy_respects_token_id_in_logprobs(router_env):
+ url = router_env["url"]
+ backend = router_env["backend"]
+ backend.reset_stats()
+
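+    # Patch the mock backend so every logprob entry carries a custom token_id; the session records should echo these ids verbatim.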
+ original_compute = backend._compute_chat_completions_response
+
+ def _custom_compute(payload: dict) -> dict:
+ response = original_compute(payload)
+ for idx, item in enumerate(response["choices"][0]["logprobs"]["content"]):
+ item["token_id"] = 900000 + idx
+ return response
+
+ backend._compute_chat_completions_response = _custom_compute
+ try:
+ session_id = _create_session(url)
+ messages = [
+ {"role": "system", "content": "sys"},
+ {"role": "user", "content": "hi"},
+ ]
+ response = requests.post(
+ f"{url}/sessions/{session_id}/v1/chat/completions",
+ json={"messages": messages},
+ )
+ assert response.status_code == 200
+
+ response_body = response.json()
+ custom_ids, response_logprobs, response_tokens = _extract_response_tokens(response_body)
+ prompt_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
+
+ get_resp = requests.get(f"{url}/sessions/{session_id}")
+ records = get_resp.json()["records"]
+ expected_token_ids = prompt_ids + custom_ids
+ assert records["token_ids"] == expected_token_ids
+ assert records["tokens"] == TOKENIZER.convert_ids_to_tokens(prompt_ids) + response_tokens
+ assert records["log_probs"] == [0.0] * len(prompt_ids) + response_logprobs
+ assert records["loss_mask"] == [0] * len(prompt_ids) + [1] * len(custom_ids)
+ finally:
+ backend._compute_chat_completions_response = original_compute
diff --git a/tests/fast/utils/__init__.py b/tests/fast/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/utils/test_arguments.py b/tests/fast/utils/test_arguments.py
new file mode 100644
index 000000000..9bd1a620d
--- /dev/null
+++ b/tests/fast/utils/test_arguments.py
@@ -0,0 +1,58 @@
+import argparse
+import sys
+from unittest.mock import patch
+
+import pytest
+
+from miles.utils.arguments import get_miles_extra_args_provider
+from miles.utils.misc import function_registry
+
+PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"]
+REQUIRED_ARGS = ["--rollout-batch-size", "64"]
+
+
+def make_class_with_add_arguments():
+ class MyFn:
+ @classmethod
+ def add_arguments(cls, parser):
+ parser.add_argument("--my-custom-arg", type=int, default=42)
+
+ return MyFn
+
+
+def make_function_with_add_arguments():
+ def my_fn():
+ pass
+
+ my_fn.add_arguments = lambda parser: parser.add_argument("--my-custom-arg", type=int, default=42)
+ return my_fn
+
+
+def make_function_without_add_arguments():
+ def my_fn():
+ pass
+
+ return my_fn
+
+
+@pytest.mark.parametrize("path_arg", PATH_ARGS)
+class TestAddArgumentsSupport:
+
+ @pytest.mark.parametrize("fn_factory", [make_class_with_add_arguments, make_function_with_add_arguments])
+ def test_add_arguments_is_called_and_arg_is_parsed(self, path_arg, fn_factory):
+ fn = fn_factory()
+ with function_registry.temporary("test:fn", fn), patch.object(
+ sys, "argv", ["test", path_arg, "test:fn", "--my-custom-arg", "100"] + REQUIRED_ARGS
+ ):
+ parser = argparse.ArgumentParser()
+ get_miles_extra_args_provider()(parser)
+ args, _ = parser.parse_known_args()
+ assert args.my_custom_arg == 100
+
+ def test_skips_function_without_add_arguments(self, path_arg):
+ fn = make_function_without_add_arguments()
+ with function_registry.temporary("test:fn", fn), patch.object(
+ sys, "argv", ["test", path_arg, "test:fn"] + REQUIRED_ARGS
+ ):
+ parser = argparse.ArgumentParser()
+ get_miles_extra_args_provider()(parser)
diff --git a/tests/utils/test_mask_utils.py b/tests/fast/utils/test_mask_utils.py
similarity index 100%
rename from tests/utils/test_mask_utils.py
rename to tests/fast/utils/test_mask_utils.py
diff --git a/tests/fast/utils/test_misc.py b/tests/fast/utils/test_misc.py
new file mode 100644
index 000000000..810c2b67c
--- /dev/null
+++ b/tests/fast/utils/test_misc.py
@@ -0,0 +1,59 @@
+import os
+
+import pytest
+
+from miles.utils.misc import FunctionRegistry, function_registry, load_function
+
+
+def _fn_a():
+ return "a"
+
+
+def _fn_b():
+ return "b"
+
+
+class TestFunctionRegistry:
+ def test_register_and_get(self):
+ registry = FunctionRegistry()
+ with registry.temporary("my_fn", _fn_a):
+ assert registry.get("my_fn") is _fn_a
+
+ def test_register_duplicate_raises(self):
+ registry = FunctionRegistry()
+ with registry.temporary("my_fn", _fn_a):
+ with pytest.raises(AssertionError):
+ with registry.temporary("my_fn", _fn_b):
+ pass
+
+ def test_unregister(self):
+ registry = FunctionRegistry()
+ with registry.temporary("my_fn", _fn_a):
+ assert registry.get("my_fn") is _fn_a
+ assert registry.get("my_fn") is None
+
+ def test_temporary_cleanup_on_exception(self):
+ registry = FunctionRegistry()
+ with pytest.raises(RuntimeError):
+ with registry.temporary("temp_fn", _fn_a):
+ raise RuntimeError("test")
+ assert registry.get("temp_fn") is None
+
+
+class TestLoadFunction:
+ def test_load_from_module(self):
+ import os.path
+
+ assert load_function("os.path.join") is os.path.join
+
+ def test_load_none_returns_none(self):
+ assert load_function(None) is None
+
+ def test_load_from_registry(self):
+ with function_registry.temporary("test:my_fn", _fn_a):
+ assert load_function("test:my_fn") is _fn_a
+
+ def test_registry_takes_precedence(self):
+ with function_registry.temporary("os.path.join", _fn_b):
+ assert load_function("os.path.join") is _fn_b
+ assert load_function("os.path.join") is os.path.join
diff --git a/tests/fast/utils/test_utils/__init__.py b/tests/fast/utils/test_utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/fast/utils/test_utils/test_mock_sglang_server.py b/tests/fast/utils/test_utils/test_mock_sglang_server.py
new file mode 100644
index 000000000..6633678da
--- /dev/null
+++ b/tests/fast/utils/test_utils/test_mock_sglang_server.py
@@ -0,0 +1,409 @@
+import asyncio
+import concurrent.futures
+import time
+
+import pytest
+import requests
+
+from miles.utils.test_utils.mock_sglang_server import (
+ Counter,
+ ProcessResult,
+ ProcessResultMetaInfo,
+ default_process_fn,
+ with_mock_server,
+)
+from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub
+
+
+def expected_logprobs(tokenizer, text: str) -> list[dict]:
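+    # Mirrors the mock server's logprob convention: the i-th generated token is assigned logprob -i / 128.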
+ output_ids = tokenizer.encode(text, add_special_tokens=False)
+ return [{"token": tokenizer.convert_ids_to_tokens(tid), "logprob": -i / 128} for i, tid in enumerate(output_ids)]
+
+
+@pytest.fixture(scope="module")
+def mock_server():
+ with with_mock_server() as server:
+ yield server
+
+
+class TestProcessResultMetaInfo:
+ def test_to_dict_empty(self):
+ assert ProcessResultMetaInfo().to_dict() == {}
+
+ def test_to_dict_single_field(self):
+ assert ProcessResultMetaInfo(weight_version="v1").to_dict() == {"weight_version": "v1"}
+
+ def test_to_dict_partial_fields(self):
+ assert ProcessResultMetaInfo(weight_version="v1", spec_accept_token_num=10).to_dict() == {
+ "weight_version": "v1",
+ "spec_accept_token_num": 10,
+ }
+
+ def test_to_dict_all_fields(self):
+ assert ProcessResultMetaInfo(
+ weight_version="v1",
+ routed_experts="abc",
+ spec_accept_token_num=10,
+ spec_draft_token_num=15,
+ spec_verify_ct=3,
+ ).to_dict() == {
+ "weight_version": "v1",
+ "routed_experts": "abc",
+ "spec_accept_token_num": 10,
+ "spec_draft_token_num": 15,
+ "spec_verify_ct": 3,
+ }
+
+
+class TestDefaultProcessFn:
+ def test_math_question(self):
+ assert default_process_fn("What is 1+5?") == ProcessResult(text="\\boxed{6}", finish_reason="stop")
+ assert default_process_fn("What is 1+10?") == ProcessResult(text="\\boxed{11}", finish_reason="stop")
+
+ def test_unknown_question(self):
+ assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop")
+
+
+class TestCounter:
+ def test_tracks_max(self):
+ counter = Counter()
+ assert counter.max_value == 0
+
+ with counter.track():
+ assert counter.max_value == 1
+ with counter.track():
+ assert counter.max_value == 2
+
+ counter.reset()
+ assert counter.max_value == 0
+
+ def test_concurrent_tasks(self):
+ counter = Counter()
+
+ async def task():
+ with counter.track():
+ await asyncio.sleep(0.1)
+
+ async def run_all():
+ await asyncio.gather(task(), task(), task())
+
+ asyncio.run(run_all())
+ assert counter.max_value == 3
+
+
+class TestMockServerBasic:
+ def test_start_stop(self, mock_server):
+ assert mock_server.port > 0
+ assert f"http://{mock_server.host}:{mock_server.port}" == mock_server.url
+
+ def test_request_log_and_reset_stats(self, mock_server):
+ mock_server.reset_stats()
+ assert len(mock_server.request_log) == 0
+
+ payload = {"input_ids": [1, 2, 3], "sampling_params": {"temperature": 0.5}, "return_logprob": True}
+ requests.post(f"{mock_server.url}/generate", json=payload, timeout=5.0)
+ assert len(mock_server.request_log) == 1
+ assert mock_server.request_log[0] == payload
+
+ mock_server.reset_stats()
+ assert len(mock_server.request_log) == 0
+ assert mock_server.max_concurrent == 0
+
+ @pytest.mark.parametrize("latency,min_time,max_time", [(0.0, 0.0, 0.3), (0.5, 0.5, 1.0)])
+ def test_latency(self, latency, min_time, max_time):
+ with with_mock_server(latency=latency) as server:
+ start = time.time()
+ requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0)
+ elapsed = time.time() - start
+ assert min_time <= elapsed < max_time
+
+ def test_max_concurrent_with_latency(self):
+ with with_mock_server(latency=0.1) as server:
+
+ def send_request():
+ requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+ futures = [executor.submit(send_request) for _ in range(3)]
+ concurrent.futures.wait(futures)
+
+ assert server.max_concurrent == 3
+
+ def test_health_endpoint(self, mock_server):
+ response = requests.get(f"{mock_server.url}/health", timeout=5.0)
+ assert response.status_code == 200
+ assert response.json() == {"status": "ok"}
+
+ def test_abort_request_endpoint(self, mock_server):
+ response = requests.post(f"{mock_server.url}/abort_request", json={}, timeout=5.0)
+ assert response.status_code == 200
+ assert response.json() == {"status": "ok"}
+
+
+class TestGenerateEndpoint:
+ def test_basic(self, mock_server):
+ prompt = "What is 1+7?"
+ input_ids = mock_server.tokenizer.encode(prompt, add_special_tokens=False)
+ assert input_ids == [3838, 374, 220, 16, 10, 22, 30]
+
+ response = requests.post(
+ f"{mock_server.url}/generate",
+ json={
+ "input_ids": input_ids,
+ "sampling_params": {"temperature": 0.7, "max_new_tokens": 10},
+ "return_logprob": True,
+ },
+ timeout=5.0,
+ )
+ assert response.status_code == 200
+ assert response.json() == {
+ "text": "\\boxed{8}",
+ "meta_info": {
+ "finish_reason": {"type": "stop"},
+ "prompt_tokens": len(input_ids),
+ "cached_tokens": 0,
+ "completion_tokens": 5,
+ "output_token_logprobs": [
+ [-0.0, 59],
+ [-0.0078125, 79075],
+ [-0.015625, 90],
+ [-0.0234375, 23],
+ [-0.03125, 92],
+ ],
+ },
+ }
+
+ def test_with_meta_info(self):
+ def process_fn(_: str) -> ProcessResult:
+ return ProcessResult(
+ text="ok",
+ finish_reason="stop",
+ cached_tokens=5,
+ meta_info=ProcessResultMetaInfo(
+ weight_version="v2.0",
+ routed_experts="encoded_data",
+ spec_accept_token_num=10,
+ spec_draft_token_num=15,
+ spec_verify_ct=3,
+ ),
+ )
+
+ with with_mock_server(process_fn=process_fn) as server:
+ response = requests.post(
+ f"{server.url}/generate",
+ json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True},
+ timeout=5.0,
+ )
+
+ assert response.json() == {
+ "text": "ok",
+ "meta_info": {
+ "finish_reason": {"type": "stop"},
+ "prompt_tokens": 3,
+ "cached_tokens": 5,
+ "completion_tokens": 1,
+ "output_token_logprobs": [[-0.0, 562]],
+ "weight_version": "v2.0",
+ "routed_experts": "encoded_data",
+ "spec_accept_token_num": 10,
+ "spec_draft_token_num": 15,
+ "spec_verify_ct": 3,
+ },
+ }
+
+ def test_finish_reason_length(self):
+ def process_fn(_: str) -> ProcessResult:
+ return ProcessResult(text="truncated output", finish_reason="length")
+
+ with with_mock_server(process_fn=process_fn) as server:
+ response = requests.post(
+ f"{server.url}/generate",
+ json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True},
+ timeout=5.0,
+ )
+ data = response.json()
+
+ finish_reason = data["meta_info"]["finish_reason"]
+ assert finish_reason["type"] == "length"
+ assert finish_reason["length"] == data["meta_info"]["completion_tokens"]
+
+
+class TestChatCompletionsEndpoint:
+ def test_basic(self, mock_server):
+ response = requests.post(
+ f"{mock_server.url}/v1/chat/completions",
+ json={
+ "model": "test-model",
+ "messages": [{"role": "user", "content": "What is 1+5?"}],
+ },
+ timeout=5.0,
+ )
+ assert response.status_code == 200
+ data = response.json()
+
+ assert data["id"].startswith("chatcmpl-")
+ assert isinstance(data["created"], int)
+ assert data == {
+ "id": data["id"],
+ "object": "chat.completion",
+ "created": data["created"],
+ "model": "mock-model",
+ "choices": [
+ {
+ "index": 0,
+ "message": {"role": "assistant", "content": "\\boxed{6}", "tool_calls": None},
+ "logprobs": {"content": expected_logprobs(mock_server.tokenizer, "\\boxed{6}")},
+ "finish_reason": "stop",
+ }
+ ],
+ }
+
+ def test_with_tool_calls(self):
+        tool_call_response = 'Let me check for you.\n<tool_call>\n{"name": "get_year", "arguments": {}}\n</tool_call>'
+
+ def process_fn(_: str) -> ProcessResult:
+ return ProcessResult(text=tool_call_response, finish_reason="stop")
+
+ with with_mock_server(process_fn=process_fn) as server:
+ response = requests.post(
+ f"{server.url}/v1/chat/completions",
+ json={
+ "model": "test",
+ "messages": [{"role": "user", "content": "What year is it?"}],
+ "tools": SAMPLE_TOOLS,
+ },
+ timeout=5.0,
+ )
+ data = response.json()
+
+ assert data["choices"][0] == {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "Let me check for you.",
+ "tool_calls": [
+ {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}}
+ ],
+ },
+ "logprobs": {"content": expected_logprobs(server.tokenizer, tool_call_response)},
+ "finish_reason": "tool_calls",
+ }
+
+ def test_with_tools_but_no_tool_call(self):
+ response_text = "The weather is sunny today."
+
+ def process_fn(_: str) -> ProcessResult:
+ return ProcessResult(text=response_text, finish_reason="stop")
+
+ with with_mock_server(process_fn=process_fn) as server:
+ response = requests.post(
+ f"{server.url}/v1/chat/completions",
+ json={
+ "model": "test",
+ "messages": [{"role": "user", "content": "What's the weather?"}],
+ "tools": SAMPLE_TOOLS,
+ },
+ timeout=5.0,
+ )
+ data = response.json()
+
+ assert data["choices"][0] == {
+ "index": 0,
+ "message": {"role": "assistant", "content": response_text, "tool_calls": None},
+ "logprobs": {"content": expected_logprobs(server.tokenizer, response_text)},
+ "finish_reason": "stop",
+ }
+
+ def test_with_multiple_tool_calls(self):
+ multi_tool_response = (
+ "I will get year and temperature.\n"
+            '<tool_call>\n{"name": "get_year", "arguments": {}}\n</tool_call>\n'
+            '<tool_call>\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n</tool_call>'
+ )
+
+ def process_fn(_: str) -> ProcessResult:
+ return ProcessResult(text=multi_tool_response, finish_reason="stop")
+
+ with with_mock_server(process_fn=process_fn) as server:
+ response = requests.post(
+ f"{server.url}/v1/chat/completions",
+ json={
+ "model": "test",
+ "messages": [{"role": "user", "content": "What year and temperature?"}],
+ "tools": SAMPLE_TOOLS,
+ },
+ timeout=5.0,
+ )
+ data = response.json()
+
+ assert data["choices"][0] == {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "I will get year and temperature.",
+ "tool_calls": [
+ {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}},
+ {
+ "id": "call00001",
+ "type": "function",
+ "function": {"name": "get_temperature", "arguments": '{"location": "Shanghai"}'},
+ },
+ ],
+ },
+ "logprobs": {"content": expected_logprobs(server.tokenizer, multi_tool_response)},
+ "finish_reason": "tool_calls",
+ }
+
+
+class TestMultiTurnToolCallProcessFn:
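+    """Drives TwoTurnStub.process_fn through the /generate and /v1/chat/completions endpoints for both scripted turns."""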
+ @pytest.mark.parametrize(
+ "prompt,expected_response",
+ [
+ pytest.param(TwoTurnStub.FIRST_PROMPT, TwoTurnStub.FIRST_RESPONSE, id="first_turn"),
+ pytest.param(TwoTurnStub.SECOND_PROMPT, TwoTurnStub.SECOND_RESPONSE, id="second_turn"),
+ ],
+ )
+ def test_generate_endpoint(self, prompt, expected_response):
+ with with_mock_server(process_fn=TwoTurnStub.process_fn) as server:
+ input_ids = server.tokenizer.encode(prompt, add_special_tokens=False)
+ response = requests.post(
+ f"{server.url}/generate",
+ json={"input_ids": input_ids, "sampling_params": {}, "return_logprob": True},
+ timeout=5.0,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["text"] == expected_response
+ assert data["meta_info"]["finish_reason"] == {"type": "stop"}
+
+ @pytest.mark.parametrize(
+ "messages,expected_content,expected_tool_calls,expected_finish_reason",
+ [
+ pytest.param(
+ TwoTurnStub.OPENAI_MESSAGES_FIRST_TURN,
+ TwoTurnStub.FIRST_RESPONSE_CONTENT,
+ TwoTurnStub.FIRST_TOOL_CALLS_OPENAI_FORMAT,
+ "tool_calls",
+ id="first_turn",
+ ),
+ pytest.param(
+ TwoTurnStub.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT,
+ TwoTurnStub.SECOND_RESPONSE,
+ None,
+ "stop",
+ id="second_turn",
+ ),
+ ],
+ )
+ def test_chat_completions_endpoint(self, messages, expected_content, expected_tool_calls, expected_finish_reason):
+ with with_mock_server(process_fn=TwoTurnStub.process_fn) as server:
+ response = requests.post(
+ f"{server.url}/v1/chat/completions",
+ json={"model": "test", "messages": messages, "tools": SAMPLE_TOOLS},
+ timeout=5.0,
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert data["choices"][0]["message"]["content"] == expected_content
+ assert data["choices"][0]["message"]["tool_calls"] == expected_tool_calls
+ assert data["choices"][0]["finish_reason"] == expected_finish_reason
diff --git a/tests/fast/utils/test_utils/test_mock_tools.py b/tests/fast/utils/test_utils/test_mock_tools.py
new file mode 100644
index 000000000..3f2116ec0
--- /dev/null
+++ b/tests/fast/utils/test_utils/test_mock_tools.py
@@ -0,0 +1,111 @@
+import asyncio
+
+import pytest
+from pydantic import TypeAdapter
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.core_types import ToolCallItem
+from sglang.srt.function_call.function_call_parser import FunctionCallParser
+
+from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub, execute_tool_call
+
+
+class TestExecuteToolCall:
+ def test_execute_get_year(self):
+ result = asyncio.run(execute_tool_call("get_year", {}))
+ assert result == '{"year": 2026}'
+
+ def test_execute_get_temperature(self):
+ result = asyncio.run(execute_tool_call("get_temperature", {"location": "Mars"}))
+ assert result == '{"temperature": -60}'
+
+
+class TestApplyChatTemplateWithTools:
+ EXPECTED_PROMPT_WITHOUT_TOOLS = (
+ "<|im_start|>user\n" "What's the weather in Paris?<|im_end|>\n" "<|im_start|>assistant\n"
+ )
+
+ EXPECTED_PROMPT_WITH_TOOLS = (
+ "<|im_start|>system\n"
+ "# Tools\n\n"
+ "You may call one or more functions to assist with the user query.\n\n"
+        "You are provided with function signatures within <tools></tools> XML tags:\n"
+        "<tools>\n"
+ '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n'
+ '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n'
+        "</tools>\n\n"
+        "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
+        "<tool_call>\n"
+        '{"name": <function-name>, "arguments": <args-json-object>}\n'
+        "</tool_call><|im_end|>\n"
+ "<|im_start|>user\n"
+ "What's the weather in Paris?<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+
+ @pytest.mark.parametrize(
+ "tools,expected",
+ [
+ pytest.param(None, EXPECTED_PROMPT_WITHOUT_TOOLS, id="without_tools"),
+ pytest.param(SAMPLE_TOOLS, EXPECTED_PROMPT_WITH_TOOLS, id="with_tools"),
+ ],
+ )
+ def test_apply_chat_template(self, tools, expected):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
+ messages = [{"role": "user", "content": "What's the weather in Paris?"}]
+
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools)
+
+ assert prompt == expected
+
+
+class TestSGLangFunctionCallParser:
+    """Tests that demonstrate and ensure the SGLang function call parser has the features we need, without breaking changes."""
+
+ @pytest.mark.parametrize(
+ "model_output,expected",
+ [
+ pytest.param(
+                'Let me check for you.\n<tool_call>\n{"name": "get_year", "arguments": {}}\n</tool_call>',
+ (
+ "Let me check for you.",
+ [ToolCallItem(tool_index=-1, name="get_year", parameters="{}")],
+ ),
+ id="single_tool_call",
+ ),
+ pytest.param(
+ "I will get year and temperature.\n"
+                '<tool_call>\n{"name": "get_year", "arguments": {}}\n</tool_call>\n'
+                '<tool_call>\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n</tool_call>',
+ (
+ "I will get year and temperature.",
+ [
+ ToolCallItem(tool_index=-1, name="get_year", parameters="{}"),
+ ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Shanghai"}'),
+ ],
+ ),
+ id="multi_tool_calls",
+ ),
+ pytest.param(
+ "The weather is sunny today.",
+ ("The weather is sunny today.", []),
+ id="no_tool_call",
+ ),
+ pytest.param(
+ TwoTurnStub.FIRST_RESPONSE,
+ (
+ "Let me get the year and temperature first.",
+ [
+ ToolCallItem(tool_index=-1, name="get_year", parameters="{}"),
+ ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Mars"}'),
+ ],
+ ),
+ id="multi_turn_first_response",
+ ),
+ ],
+ )
+ def test_parse_non_stream(self, model_output, expected):
+ tools = TypeAdapter(list[Tool]).validate_python(SAMPLE_TOOLS)
+ parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25")
+ assert parser.parse_non_stream(model_output) == expected
diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py
index c5c0838c5..9b6e69c29 100644
--- a/tests/test_external_rollout.py
+++ b/tests/test_external_rollout.py
@@ -126,6 +126,7 @@ def execute():
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
before_ray_job_submit=_launch_background,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_mimo_7B_mtp_only_grad.py b/tests/test_mimo_7B_mtp_only_grad.py
index 97c76ace5..d90a2d7a7 100644
--- a/tests/test_mimo_7B_mtp_only_grad.py
+++ b/tests/test_mimo_7B_mtp_only_grad.py
@@ -135,6 +135,7 @@ def execute():
train_args=train_args,
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_moonlight_16B_A3B.py b/tests/test_moonlight_16B_A3B.py
index b1255982e..c35943ec1 100644
--- a/tests/test_moonlight_16B_A3B.py
+++ b/tests/test_moonlight_16B_A3B.py
@@ -113,6 +113,7 @@ def execute():
train_args=train_args,
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py
index 15ca8ce5f..ae3c383ae 100644
--- a/tests/test_quick_start_glm4_9B.py
+++ b/tests/test_quick_start_glm4_9B.py
@@ -115,6 +115,7 @@ def execute():
train_args=train_args,
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py
index dcdbd5834..4d7f034f6 100644
--- a/tests/test_qwen2.5_0.5B_gsm8k.py
+++ b/tests/test_qwen2.5_0.5B_gsm8k.py
@@ -120,6 +120,7 @@ def execute():
train_args=train_args,
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py
index dcaaf5e1f..32b60f593 100644
--- a/tests/test_qwen2.5_0.5B_gsm8k_async.py
+++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py
@@ -120,6 +120,7 @@ def execute():
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
train_script="train_async.py",
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py
index d55262cd0..8ce1988de 100644
--- a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py
+++ b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py
@@ -118,6 +118,7 @@ def execute():
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
train_script="train_async.py",
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen2.5_0.5B_gsm8k_short.py b/tests/test_qwen2.5_0.5B_gsm8k_short.py
index afbffbc56..87edf266f 100644
--- a/tests/test_qwen2.5_0.5B_gsm8k_short.py
+++ b/tests/test_qwen2.5_0.5B_gsm8k_short.py
@@ -117,6 +117,7 @@ def execute():
train_args=train_args,
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py
index 3d19b48ce..3d4768e42 100644
--- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py
+++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py
@@ -93,6 +93,7 @@ def execute():
train_args=train_args,
num_gpus_per_node=2,
megatron_model_type=None,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py
index 3d70f3e4c..fcd777288 100644
--- a/tests/test_qwen3_0.6B_fsdp_distributed.py
+++ b/tests/test_qwen3_0.6B_fsdp_distributed.py
@@ -95,6 +95,7 @@ def execute():
num_gpus_per_node=2 if FEW_GPU else 4,
megatron_model_type=None,
train_script="train_async.py",
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/test_qwen3_0.6B_megatron_fsdp_align.py
index 1431d8c3d..b89a2f283 100644
--- a/tests/test_qwen3_0.6B_megatron_fsdp_align.py
+++ b/tests/test_qwen3_0.6B_megatron_fsdp_align.py
@@ -97,6 +97,7 @@ def execute():
train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "),
num_gpus_per_node=NUM_GPUS,
megatron_model_type=None,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
U.execute_train(
@@ -109,6 +110,7 @@ def execute():
),
num_gpus_per_node=NUM_GPUS,
megatron_model_type=None,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
U.execute_train(
@@ -135,6 +137,7 @@ def execute():
"--debug-train-only "
),
num_gpus_per_node=NUM_GPUS,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
megatron_model_type=MODEL_TYPE,
)
diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/test_qwen3_0.6B_parallel_check.py
index 44f5c42fa..d0ad283d1 100644
--- a/tests/test_qwen3_0.6B_parallel_check.py
+++ b/tests/test_qwen3_0.6B_parallel_check.py
@@ -95,6 +95,7 @@ def execute():
),
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
# 8 GPU CPU 1
for num_gpus in [8, 4, 2]:
@@ -124,6 +125,7 @@ def execute():
train_args=args,
num_gpus_per_node=num_gpus,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
train_args += "--calculate-per-token-loss "
diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py
index adff10804..b30eeed8e 100644
--- a/tests/test_qwen3_30B_A3B.py
+++ b/tests/test_qwen3_30B_A3B.py
@@ -139,6 +139,7 @@ def execute():
train_args=train_args,
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen3_4B_ckpt.py b/tests/test_qwen3_4B_ckpt.py
index 22fb2b5fc..0df4492e1 100644
--- a/tests/test_qwen3_4B_ckpt.py
+++ b/tests/test_qwen3_4B_ckpt.py
@@ -124,6 +124,7 @@ def execute(mode: str = ""):
train_args=train_args,
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen3_4B_fsdp_true_on_policy.py b/tests/test_qwen3_4B_fsdp_true_on_policy.py
index 7c975c7cc..03ba4094e 100644
--- a/tests/test_qwen3_4B_fsdp_true_on_policy.py
+++ b/tests/test_qwen3_4B_fsdp_true_on_policy.py
@@ -95,6 +95,7 @@ def execute():
"NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",
"CUBLAS_WORKSPACE_CONFIG": ":4096:8",
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
+ "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1",
}
U.execute_train(
diff --git a/tests/test_qwen3_4B_ppo.py b/tests/test_qwen3_4B_ppo.py
index 962f610fa..d4c1ac273 100644
--- a/tests/test_qwen3_4B_ppo.py
+++ b/tests/test_qwen3_4B_ppo.py
@@ -122,6 +122,7 @@ def execute():
train_args=train_args,
num_gpus_per_node=NUM_GPUS,
megatron_model_type=MODEL_TYPE,
+ extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"},
)
diff --git a/tests/test_qwen3_vl_4B_fsdp.py b/tests/test_qwen3_vl_4B_fsdp.py
index fbdffd237..bc4ef3293 100644
--- a/tests/test_qwen3_vl_4B_fsdp.py
+++ b/tests/test_qwen3_vl_4B_fsdp.py
@@ -92,6 +92,7 @@ def execute():
extra_env_vars = {
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
+ "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1",
}
U.execute_train(
diff --git a/tests/utils/sglang_stub.py b/tests/utils/sglang_stub.py
new file mode 100644
index 000000000..6eece91f8
--- /dev/null
+++ b/tests/utils/sglang_stub.py
@@ -0,0 +1,44 @@
+import sys
+import types
+
+
+def _ensure_package(name: str) -> None:
+ module = sys.modules.get(name)
+ if module is None:
+ module = types.ModuleType(name)
+ module.__path__ = []
+ sys.modules[name] = module
+
+
+def install_sglang_stub() -> None:
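+    """Register lightweight stand-ins for the sglang package tree and its OpenAI protocol message classes in sys.modules."""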
+ _ensure_package("sglang")
+ _ensure_package("sglang.srt")
+ _ensure_package("sglang.srt.endpoints")
+ _ensure_package("sglang.srt.endpoints.openai")
+ _ensure_package("sglang.srt.entrypoints")
+ _ensure_package("sglang.srt.entrypoints.openai")
+
+    protocol_module = types.ModuleType("sglang.srt.endpoints.openai.protocol")
+
+ class ChatCompletionMessageGenericParam:
+ def __init__(self, role: str, content: str | None = None, **kwargs):
+ self.role = role
+ self.content = content
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ def model_copy(self, update: dict):
+ data = self.__dict__.copy()
+ data.update(update)
+ return self.__class__(**data)
+
+ class ChatCompletionMessageUserParam(ChatCompletionMessageGenericParam):
+ pass
+
+    ChatCompletionMessageParam = ChatCompletionMessageGenericParam | ChatCompletionMessageUserParam
+
+    protocol_module.ChatCompletionMessageGenericParam = ChatCompletionMessageGenericParam
+    protocol_module.ChatCompletionMessageUserParam = ChatCompletionMessageUserParam
+    protocol_module.ChatCompletionMessageParam = ChatCompletionMessageParam
+    sys.modules["sglang.srt.endpoints.openai.protocol"] = protocol_module
+    sys.modules["sglang.srt.entrypoints.openai.protocol"] = protocol_module
diff --git a/tests/utils/test_chat_message_utils.py b/tests/utils/test_chat_message_utils.py
new file mode 100644
index 000000000..e721af984
--- /dev/null
+++ b/tests/utils/test_chat_message_utils.py
@@ -0,0 +1,41 @@
+import pytest
+
+from miles.utils.chat_message_utils import get_think_token_end, get_think_token_start, trim_think_tokens
+
+
+def test_get_think_token_start_end():
+    assert get_think_token_start("qwen3") == ("<think>", 151667)
+    assert get_think_token_end("qwen3") == ("</think>", 151668)
+
+
+def test_trim_think_tokens_no_think():
+ tokens = [1, 2, 3]
+ assert trim_think_tokens(tokens, "qwen3") == tokens
+
+
+def test_trim_think_tokens_start_only():
+ tokens = [1, 151667, 2, 3]
+ assert trim_think_tokens(tokens, "qwen3") == [1]
+
+
+def test_trim_think_tokens_start_and_end():
+ tokens = [1, 151667, 2, 151668, 3]
+ assert trim_think_tokens(tokens, "qwen3") == [1]
+
+
+def test_trim_think_tokens_end_without_start():
+ tokens = [1, 151668, 2]
+ with pytest.raises(ValueError, match="No think token start found"):
+ trim_think_tokens(tokens, "qwen3")
+
+
+def test_trim_think_tokens_multiple_starts():
+ tokens = [151667, 1, 151667]
+ with pytest.raises(ValueError, match="Multiple think token start found"):
+ trim_think_tokens(tokens, "qwen3")
+
+
+def test_trim_think_tokens_multiple_ends():
+ tokens = [151667, 1, 151668, 2, 151668]
+ with pytest.raises(ValueError, match="Multiple think token end found"):
+ trim_think_tokens(tokens, "qwen3")