diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index f00faa5a6..1e26168d6 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -25,6 +25,46 @@ concurrency: jobs: + fast: + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request) + runs-on: self-hosted + container: + image: radixark/miles:latest + options: > + --gpus all + --ipc=host + --shm-size=16g + --ulimit memlock=-1 + --ulimit stack=67108864 + --memory=0 + --memory-swap=0 + -v /mnt/nvme0n1/miles_ci:/data/miles_ci + -v /mnt/nvme0n1/miles_ci/models:/root/models + -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets + strategy: + fail-fast: false + matrix: + info: [{"num_gpus": 0, "test_file": "fast"}] + defaults: + run: + working-directory: ${{ github.workspace }} + env: + GITHUB_COMMIT_NAME: ${{ github.sha }}_${{ github.event.pull_request.number || 'non-pr' }} + WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} + MILES_TEST_ENABLE_INFINITE_RUN: ${{ (github.event_name == 'workflow_dispatch' && github.event.inputs.infinite_run) || 'false' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install + shell: bash + run: cd $GITHUB_WORKSPACE && pip install -e . --no-deps --break-system-packages + + - name: Execute + shell: bash + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- pytest tests/${{ matrix.info.test_file }} + e2e-test-short: if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, 'run-ci-short')) runs-on: self-hosted @@ -38,10 +78,6 @@ jobs: --ulimit stack=67108864 --memory=0 --memory-swap=0 - -e http_proxy=$http_proxy - -e https_proxy=$https_proxy - -e HTTP_PROXY=$HTTP_PROXY - -e HTTPS_PROXY=$HTTPS_PROXY -v /mnt/nvme0n1/miles_ci:/data/miles_ci -v /mnt/nvme0n1/miles_ci/models:/root/models -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets @@ -82,10 +118,6 @@ jobs: --ulimit stack=67108864 --memory=0 --memory-swap=0 - -e http_proxy=$http_proxy - -e https_proxy=$https_proxy - -e HTTP_PROXY=$HTTP_PROXY - -e HTTPS_PROXY=$HTTPS_PROXY -v /mnt/nvme0n1/miles_ci:/data/miles_ci -v /mnt/nvme0n1/miles_ci/models:/root/models -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets @@ -126,10 +158,6 @@ jobs: --ulimit stack=67108864 --memory=0 --memory-swap=0 - -e http_proxy=$http_proxy - -e https_proxy=$https_proxy - -e HTTP_PROXY=$HTTP_PROXY - -e HTTPS_PROXY=$HTTPS_PROXY -v /mnt/nvme0n1/miles_ci:/data/miles_ci -v /mnt/nvme0n1/miles_ci/models:/root/models -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets @@ -170,10 +198,6 @@ jobs: --ulimit stack=67108864 --memory=0 --memory-swap=0 - -e http_proxy=$http_proxy - -e https_proxy=$https_proxy - -e HTTP_PROXY=$HTTP_PROXY - -e HTTPS_PROXY=$HTTPS_PROXY -v /mnt/nvme0n1/miles_ci:/data/miles_ci -v /mnt/nvme0n1/miles_ci/models:/root/models -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets @@ -214,10 +238,6 @@ jobs: --ulimit stack=67108864 --memory=0 --memory-swap=0 - -e http_proxy=$http_proxy - -e https_proxy=$https_proxy - -e HTTP_PROXY=$HTTP_PROXY - -e HTTPS_PROXY=$HTTPS_PROXY -v /mnt/nvme0n1/miles_ci:/data/miles_ci -v /mnt/nvme0n1/miles_ci/models:/root/models -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets @@ -258,10 +278,6 @@ jobs: --ulimit stack=67108864 --memory=0 --memory-swap=0 - -e http_proxy=$http_proxy - -e https_proxy=$https_proxy - -e HTTP_PROXY=$HTTP_PROXY - -e HTTPS_PROXY=$HTTPS_PROXY -v /mnt/nvme0n1/miles_ci:/data/miles_ci -v /mnt/nvme0n1/miles_ci/models:/root/models -v 
/mnt/nvme0n1/miles_ci/datasets:/root/datasets @@ -302,10 +318,6 @@ jobs: --ulimit stack=67108864 --memory=0 --memory-swap=0 - -e http_proxy=$http_proxy - -e https_proxy=$https_proxy - -e HTTP_PROXY=$HTTP_PROXY - -e HTTPS_PROXY=$HTTPS_PROXY -v /mnt/nvme0n1/miles_ci:/data/miles_ci -v /mnt/nvme0n1/miles_ci/models:/root/models -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 25bb2bce2..055dfee63 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -1,4 +1,10 @@ <% set jobs = { + 'fast': { + 'test_executor': 'pytest', + 'tests': [ + {'test_file': 'fast', 'num_gpus': 0}, + ], + }, 'e2e-test-short': { 'label': 'run-ci-short', 'tests': [ @@ -95,7 +101,7 @@ concurrency: jobs: <% for job_name, config in jobs.items() %> << job_name >>: - if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request && contains(github.event.pull_request.labels.*.name, '<< config.label >>')) + if: (github.event_name == 'workflow_dispatch') || (github.event.pull_request<% if config.label %> && contains(github.event.pull_request.labels.*.name, '<< config.label >>')<% endif %>) runs-on: self-hosted container: image: << config.image if config.image else 'radixark/miles:latest' >> @@ -107,10 +113,6 @@ jobs: --ulimit stack=67108864 --memory=0 --memory-swap=0 - -e http_proxy=$http_proxy - -e https_proxy=$https_proxy - -e HTTP_PROXY=$HTTP_PROXY - -e HTTPS_PROXY=$HTTPS_PROXY -v /mnt/nvme0n1/miles_ci:/data/miles_ci -v /mnt/nvme0n1/miles_ci/models:/root/models -v /mnt/nvme0n1/miles_ci/datasets:/root/datasets @@ -136,5 +138,5 @@ jobs: - name: Execute shell: bash - run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} + run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- << config.test_executor | default('python') >> tests/${{ matrix.info.test_file }} <% endfor %> \ No newline at end of file diff --git a/examples/openai_format/__init__.py b/examples/openai_format/__init__.py new file mode 100644 index 000000000..30436bcc4 --- /dev/null +++ b/examples/openai_format/__init__.py @@ -0,0 +1 @@ +"""OpenAI format examples.""" diff --git a/examples/openai_format/dapo_math.py b/examples/openai_format/dapo_math.py new file mode 100644 index 000000000..6fe69433e --- /dev/null +++ b/examples/openai_format/dapo_math.py @@ -0,0 +1,57 @@ +""" +DAPO math OpenAI format example for token in/out verification. +""" + +import argparse +from typing import Any + +from openai import AsyncOpenAI + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.openai_endpoint_utils import ( + OpenAIEndpointTracer, + compute_samples_from_openai_records, +) +from miles.rollout.generate_utils.sample_utils import merge_samples + +_DAPO_MATH_SYSTEM_PROMPT = ( + "Solve the math problem and return the final answer as \\boxed{integer}. " + "Keep the reasoning concise and finish with the boxed answer." 
+) + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + tracer = await OpenAIEndpointTracer.create(input.args) + messages = _normalize_prompt(input.sample.prompt) + await _run_single_turn_openai(base_url=tracer.base_url, messages=messages) + + records = await tracer.collect_records() + samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer) + if not input.args.generate_multi_samples: + samples = merge_samples(samples, input.state.tokenizer) + return GenerateFnOutput(samples=samples) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments + + +def build_dapo_math_messages(question: str) -> list[dict[str, str]]: + return [ + {"role": "system", "content": _DAPO_MATH_SYSTEM_PROMPT}, + {"role": "user", "content": question}, + ] + + +def _normalize_prompt(prompt: Any) -> list[dict[str, Any]]: + if isinstance(prompt, list): + return prompt + return build_dapo_math_messages(prompt) + + +async def _run_single_turn_openai(base_url: str, messages: list[dict[str, Any]]) -> None: + client = AsyncOpenAI(base_url=base_url, api_key="empty") + await client.chat.completions.create(model="default", messages=messages) diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 79c6649be..27211845d 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -13,8 +13,15 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from miles.backends.sglang_utils.sglang_engine import SGLangEngine -from miles.rollout.base_types import call_rollout_fn +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnTrainInput, + call_rollout_fn, +) +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function from miles.utils import tracking_utils +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.health_monitor import RolloutHealthMonitor from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client from miles.utils.iter_utils import group_by @@ -53,8 +60,14 @@ def __init__(self, args, pg): data_source_cls = load_function(self.args.data_source_path) self.data_source = data_source_cls(args) - self.generate_rollout = load_function(self.args.rollout_function_path) - self.eval_generate_rollout = load_function(self.args.eval_function_path) + self.use_experimental_refactor = enable_experimental_rollout_refactor() + if self.use_experimental_refactor: + input = RolloutFnConstructorInput(args=args, data_source=self.data_source) + self.generate_rollout = load_rollout_function(input, self.args.rollout_function_path) + self.eval_generate_rollout = load_rollout_function(input, self.args.eval_function_path) + else: + self.generate_rollout = load_function(self.args.rollout_function_path) + self.eval_generate_rollout = load_function(self.args.eval_function_path) self.custom_reward_post_process_func = None if self.args.custom_reward_post_process_path is not None: self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path) @@ -142,7 +155,12 @@ def eval(self, rollout_id): return self.health_monitoring_resume() - result = call_rollout_fn(self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True) + if self.use_experimental_refactor: + result = call_rollout_function(self.eval_generate_rollout, 
RolloutFnEvalInput(rollout_id=rollout_id)) + else: + result = call_rollout_fn( + self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True + ) data = result.data self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True) metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics) @@ -224,7 +242,12 @@ def _get_rollout_data(self, rollout_id): ) metrics = None else: - data = call_rollout_fn(self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False) + if self.use_experimental_refactor: + data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id)) + else: + data = call_rollout_fn( + self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False + ) metrics = data.metrics data = data.samples # flatten the data if it is a list of lists diff --git a/miles/rollout/base_types.py b/miles/rollout/base_types.py index faa85c726..c2644e87f 100644 --- a/miles/rollout/base_types.py +++ b/miles/rollout/base_types.py @@ -1,22 +1,86 @@ +from __future__ import annotations + +from argparse import Namespace from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any +from miles.rollout.data_source import DataSource from miles.utils.types import Sample +if TYPE_CHECKING: + from miles.rollout.inference_rollout.inference_rollout_common import GenerateState + + +@dataclass(frozen=True) +class RolloutFnConstructorInput: + args: Namespace + # TODO may refactor DataSource API + data_source: DataSource + + +@dataclass(frozen=True) +class RolloutFnBaseInput: + rollout_id: int + + @property + def evaluation(self): + raise NotImplementedError + + +# subclassing for different data in the future +@dataclass(frozen=True) +class RolloutFnTrainInput(RolloutFnBaseInput): + @property + def evaluation(self): + return False + +@dataclass(frozen=True) +class RolloutFnEvalInput(RolloutFnBaseInput): + @property + def evaluation(self): + return True + + +# TODO make it frozen @dataclass class RolloutFnTrainOutput: samples: list[list[Sample]] metrics: dict[str, Any] = None +# TODO make it frozen @dataclass class RolloutFnEvalOutput: data: dict[str, dict[str, Any]] metrics: dict[str, Any] = None +RolloutFnInput = RolloutFnTrainInput | RolloutFnEvalInput +RolloutFnOutput = RolloutFnTrainOutput | RolloutFnEvalOutput + + +@dataclass(frozen=True) +class GenerateFnInput: + state: GenerateState + sample: Sample + sampling_params: dict[str, Any] + evaluation: bool + + @property + def args(self) -> Namespace: + return self.state.args + + +@dataclass(frozen=True) +class GenerateFnOutput: + # One generate may lead to multiple samples, such as multi-agent, tree-like exploration, or + # multi-turn with removing thinking tokens. + samples: Sample | list[Sample] + + def call_rollout_fn(fn, *args, evaluation: bool, **kwargs): + """Legacy rollout function call interface. Used when MILES_EXPERIMENTAL_ROLLOUT_REFACTOR is disabled.""" output = fn(*args, **kwargs, evaluation=evaluation) # compatibility for legacy version diff --git a/miles/rollout/generate_hub/agentic_tool_call.py b/miles/rollout/generate_hub/agentic_tool_call.py new file mode 100644 index 000000000..05223a654 --- /dev/null +++ b/miles/rollout/generate_hub/agentic_tool_call.py @@ -0,0 +1,85 @@ +""" +Simple agentic demo with tool calling. 
+""" + +import argparse +from copy import deepcopy +from typing import Any + +from openai import AsyncOpenAI + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.openai_endpoint_utils import ( + OpenAIEndpointTracer, + compute_samples_from_openai_records, +) +from miles.rollout.generate_utils.sample_utils import merge_samples +from miles.rollout.generate_utils.tool_call_utils import execute_tool_calls +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + tracer = await OpenAIEndpointTracer.create(input.args) + + await _run_blackbox_tool_call_agent( + base_url=tracer.base_url, + prompt=input.sample.prompt, + max_turns=input.args.generate_max_turns, + tool_specs_path=input.args.generate_tool_specs_path, + execute_tool_function_path=input.args.generate_execute_tool_function_path, + ) + + records = await tracer.collect_records() + samples = compute_samples_from_openai_records(input.sample, records, input.state.tokenizer) + if not input.args.generate_multi_samples: + samples = merge_samples(samples, input.state.tokenizer) + return GenerateFnOutput(samples=samples) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments + + +async def _run_blackbox_tool_call_agent( + base_url: str, + prompt: list[dict[str, Any]], + max_turns: int, + tool_specs_path: str, + execute_tool_function_path: str, +): + """ + Imagine this is a black-box agent, e.g. SWE-agent, which does arbitrarily complex work, + only understands OpenAI compatible API, and never understands Miles or the Sample data structure. + """ + + # ----------------------- Setup ------------------------- + + client = AsyncOpenAI(base_url=base_url, api_key="empty") + execute_tool_function = load_function(execute_tool_function_path) + tool_specs = load_function(tool_specs_path) + + # ----------------------- Initial prompts ------------------------- + + messages = deepcopy(prompt) + + for _turn in range(max_turns): + # ----------------------- Call inference endpoint ------------------------- + + response = await client.chat.completions.create(model="default", messages=messages, tools=tool_specs) + + choice = response.choices[0] + messages.append(choice.message.model_dump()) + + if choice.finish_reason in ("stop", "length"): + break + + # ----------------------- Execute tools ------------------------- + + if x := choice.message.tool_calls: + messages += await execute_tool_calls(x, execute_tool_function) diff --git a/miles/rollout/generate_hub/multi_turn.py b/miles/rollout/generate_hub/multi_turn.py new file mode 100644 index 000000000..97814ecb3 --- /dev/null +++ b/miles/rollout/generate_hub/multi_turn.py @@ -0,0 +1,88 @@ +""" +Simple multi-turn generation with tool calling. 
+""" + +import argparse +from copy import deepcopy + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.generate_endpoint_utils import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.rollout.generate_utils.tool_call_utils import ( + create_tool_call_parser, + execute_tool_calls, + update_sample_with_tool_responses, +) +from miles.utils.http_utils import post +from miles.utils.misc import load_function + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + # ----------------------- Setup ------------------------- + + args = input.args + sample = deepcopy(input.sample) + tokenizer = input.state.tokenizer + assert not args.partial_rollout, "Partial rollout is not supported" + + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + execute_tool_function = load_function(args.generate_execute_tool_function_path) + + tool_specs = load_function(args.generate_tool_specs_path) + tool_call_parser = create_tool_call_parser(tool_specs, args.generate_tool_call_parser) + + multi_samples = [] + + # ----------------------- Initial prompts ------------------------- + + prompt_tokens_ids = compute_prompt_ids_from_sample(input.state, sample, tools=tool_specs) + + sample.tokens = prompt_tokens_ids.copy() + + for _turn in range(args.generate_max_turns): + # ----------------------- Call inference endpoint ------------------------- + + payload, halt_status = compute_request_payload(args, sample.tokens, input.sampling_params) + if payload is None: + sample.status = halt_status + if args.generate_multi_samples and multi_samples: + multi_samples[-1].status = halt_status + break + + if args.generate_multi_samples: + sample = deepcopy(input.sample) + + output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output, update_loss_mask=True) + + if args.generate_multi_samples: + multi_samples.append(deepcopy(sample)) + + if output["meta_info"]["finish_reason"]["type"] in ("abort", "length"): + break + + # ----------------------- Execute tools ------------------------- + + _, tool_calls = tool_call_parser.parse_non_stream(output["text"]) + if len(tool_calls) == 0: + break + + tool_messages = await execute_tool_calls(tool_calls, execute_tool_function) + update_sample_with_tool_responses(sample, tool_messages, tokenizer=tokenizer) + + return GenerateFnOutput(samples=multi_samples if args.generate_multi_samples else sample) + + +def _add_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--generate-max-turns", type=int, default=16) + parser.add_argument("--generate-tool-specs-path", type=str) + parser.add_argument("--generate-tool-call-parser", type=str) + parser.add_argument("--generate-execute-tool-function-path", type=str) + parser.add_argument("--generate-multi-samples", action="store_true") + + +generate.add_arguments = _add_arguments diff --git a/miles/rollout/generate_hub/single_turn.py b/miles/rollout/generate_hub/single_turn.py new file mode 100644 index 000000000..5c0a15b5b --- /dev/null +++ b/miles/rollout/generate_hub/single_turn.py @@ -0,0 +1,46 @@ +""" +Simple single-turn generation. 
+""" + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.rollout.generate_utils.generate_endpoint_utils import ( + compute_prompt_ids_from_sample, + compute_request_payload, + update_sample_from_response, +) +from miles.utils.http_utils import post +from miles.utils.types import Sample + + +async def generate(input: GenerateFnInput) -> GenerateFnOutput: + args = input.args + sample = input.sample + sampling_params = input.sampling_params + assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}" + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + prompt_ids = compute_prompt_ids_from_sample(input.state, sample) + + # Handle Partial Rollout resuming + if len(sample.response) > 0: + input_ids = sample.tokens + sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) + + assert sampling_params["max_new_tokens"] >= 0 + if sampling_params["max_new_tokens"] == 0: + sample.status = Sample.Status.TRUNCATED + return GenerateFnOutput(samples=sample) + else: + input_ids = prompt_ids + + payload, halt_status = compute_request_payload( + args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs + ) + if payload is None: + sample.status = halt_status + return GenerateFnOutput(samples=sample) + + output = await post(url, payload) + await update_sample_from_response(args, sample, payload=payload, output=output) + + return GenerateFnOutput(samples=sample) diff --git a/miles/rollout/generate_utils/__init__.py b/miles/rollout/generate_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/miles/rollout/generate_utils/generate_endpoint_utils.py b/miles/rollout/generate_utils/generate_endpoint_utils.py new file mode 100644 index 000000000..a91d71f1d --- /dev/null +++ b/miles/rollout/generate_utils/generate_endpoint_utils.py @@ -0,0 +1,112 @@ +""" +Utils to integrate SGLang's `/generate` endpoint with RL things like Sample. +""" + +from copy import deepcopy +from typing import Any + +import numpy as np +import pybase64 + +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.types import Sample + + +# Make this an isolated function because users may want to compute their own +def compute_prompt_ids_from_sample(state, sample, tools=None): + prompt = sample.prompt + + if state.processor: + processor_output = state.processor(text=prompt, **sample.multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + + # TODO shall we move it to other places? 
then can make this function immutable + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + + return prompt_ids + else: + if not isinstance(prompt, str): + prompt = state.tokenizer.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True, tools=tools + ) + + return state.tokenizer.encode(prompt, add_special_tokens=False) + + +def compute_request_payload( + args, + input_ids: list[int], + sampling_params: dict, + multimodal_inputs: dict | None = None, +) -> tuple[dict[str, Any] | None, Sample.Status | None]: + sampling_params = deepcopy(sampling_params) + max_new_tokens = sampling_params.pop("max_new_tokens", args.rollout_max_response_len) + if x := args.rollout_max_context_len: + max_new_tokens = min(max_new_tokens, x - len(input_ids)) + if max_new_tokens <= 0: + return None, Sample.Status.TRUNCATED + + payload = { + "input_ids": input_ids, + "sampling_params": {**sampling_params, "max_new_tokens": max_new_tokens}, + "return_logprob": True, + "return_routed_experts": args.use_rollout_routing_replay, + } + if image_data := (multimodal_inputs or {}).get("images"): + payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + + return payload, None + + +async def update_sample_from_response( + args, sample: Sample, payload: dict, output: dict, update_loss_mask: bool = False +): + # Initialize sample.tokens for the first turn + if (len(sample.response) == 0) and not sample.tokens: + sample.tokens = payload["input_ids"] + + if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: + from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + + # TODO may rename to match + await postprocess_sample_with_radix_tree(args, sample, output) + + assert not update_loss_mask, "This code branch has not implemented update_loss_mask" + else: + if x := output["meta_info"].get("output_token_logprobs"): + new_response_tokens = [item[1] for item in x] + new_response_log_probs = [item[0] for item in x] + else: + new_response_tokens, new_response_log_probs = [], [] + + # Update sample with tokens directly - avoiding re-tokenization + sample.tokens = sample.tokens + new_response_tokens + sample.response_length += len(new_response_tokens) + sample.response += output["text"] + + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += new_response_log_probs + + if update_loss_mask: + if sample.loss_mask is None: + sample.loss_mask = [] + sample.loss_mask += [1] * len(new_response_tokens) + + # TODO handle multi-turn cases (may need concat instead of assignment) + sample.rollout_routed_experts = _get_rollout_routed_experts_from_response(args, sample, output) + + # TODO may unify (currently there are both methods inside Sample and separate functions) + sample.update_from_meta_info(args, output["meta_info"]) + + +def _get_rollout_routed_experts_from_response(args, sample, output): + info = output["meta_info"].get("routed_experts") + if info is None: + return None + + x = np.frombuffer(pybase64.b64decode(info.encode("ascii")), dtype=np.int32) + x = x.reshape(len(sample.tokens) - 1, args.num_layers, args.moe_router_topk) + return x diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py new file mode 100644 index 000000000..a7a3a3e4a --- /dev/null +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -0,0 
+1,69 @@ +""" +Utilities for the OpenAI endpoint +""" + +import logging +from argparse import Namespace +from copy import deepcopy + +from miles.router.session.sessions import GetSessionResponse, SessionRecord +from miles.utils.http_utils import post +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class OpenAIEndpointTracer: + def __init__(self, router_url: str, session_id: str): + self.router_url = router_url + self.session_id = session_id + self.base_url = f"{router_url}/sessions/{session_id}/v1" + + @staticmethod + async def create(args: Namespace): + router_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" + session_id = (await post(f"{router_url}/sessions", {}))["session_id"] + return OpenAIEndpointTracer(router_url=router_url, session_id=session_id) + + async def collect_records(self) -> list[SessionRecord]: + response = await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="get") + response = GetSessionResponse.model_validate(response) + records = response.session_records + if records is None and isinstance(response.records, list): + records = response.records + + try: + await post(f"{self.router_url}/sessions/{self.session_id}", {}, action="delete") + except Exception as e: + logger.warning(f"Failed to delete session {self.session_id} after collecting records: {e}") + + return records or [] + + +def compute_samples_from_openai_records(input_sample: Sample, records: list[SessionRecord], tokenizer) -> list[Sample]: + return [_compute_sample_from_openai_record(input_sample, record, tokenizer) for record in records] + + +def _compute_sample_from_openai_record(input_sample: Sample, record: SessionRecord, tokenizer) -> Sample: + # TODO may refine after @guapisolo's implementation + choice = record.response["choices"][0] + output_token_ids = [item["token_id"] for item in choice["logprobs"]["content"]] + output_log_probs = [item["logprob"] for item in choice["logprobs"]["content"]] + + sample = deepcopy(input_sample) + sample.tokens = record.request["input_ids"] + output_token_ids + sample.rollout_log_probs = output_log_probs + sample.response = tokenizer.decode(output_token_ids) + sample.response_length = len(output_token_ids) + sample.loss_mask = [1] * len(output_token_ids) + + # TODO unify with Sample.update_from_meta_info + match choice["finish_reason"]: + case "stop" | "tool_calls": + sample.status = Sample.Status.COMPLETED + case "length": + sample.status = Sample.Status.TRUNCATED + case "abort": + sample.status = Sample.Status.ABORTED + + return sample diff --git a/miles/rollout/generate_utils/sample_utils.py b/miles/rollout/generate_utils/sample_utils.py new file mode 100644 index 000000000..6a4e645be --- /dev/null +++ b/miles/rollout/generate_utils/sample_utils.py @@ -0,0 +1,115 @@ +from copy import deepcopy +from dataclasses import fields + +from miles.utils.types import Sample + + +def merge_samples(samples: list[Sample], tokenizer) -> Sample: + acc = samples[0] + for sample in samples[1:]: + acc = _merge_sample_pair(acc, sample, tokenizer=tokenizer) + return acc + + +def _merge_sample_pair(a: Sample, b: Sample, tokenizer) -> Sample: + """Merge two samples generated from sibling inference engine calls.""" + a, b = deepcopy(a), deepcopy(b) + + def _merge_equal_value(field): + x = getattr(a, field) + y = getattr(b, field) + assert x == y, f"{field} mismatch: a.{field}={x}, b.{field}={y}" + return x + + def _fill_defaults(sample: Sample): + if sample.loss_mask is None: + sample.loss_mask = [1] * 
sample.response_length + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [0.0] * sample.response_length + + _fill_defaults(a) + _fill_defaults(b) + + obs_len = len(b.tokens) - len(a.tokens) - b.response_length + obs_tokens = b.tokens[len(a.tokens) : len(a.tokens) + obs_len] + # TODO: is this acceptable? + obs_text = tokenizer.decode(obs_tokens) + + try: + a.validate() + b.validate() + assert _startswith(short=a.prompt, long=b.prompt), "b.prompt must start with a.prompt" + assert _startswith(short=a.tokens, long=b.tokens), "b.tokens must start with a.tokens" + assert obs_len > 0, f"obs_len must be > 0, got {obs_len}" + if a.rollout_routed_experts is not None: + assert a.rollout_routed_experts.shape[0] <= b.rollout_routed_experts.shape[0] + assert a.status == Sample.Status.COMPLETED, f"a.status must be COMPLETED, got {a.status}" + + return _create_with_all_fields( + Sample, + group_index=_merge_equal_value("group_index"), + index=_merge_equal_value("index"), + prompt=b.prompt, + tokens=b.tokens, + multimodal_inputs=_merge_equal_value("multimodal_inputs"), + multimodal_train_inputs=_merge_equal_value("multimodal_train_inputs"), + response=a.response + obs_text + b.response, + response_length=a.response_length + obs_len + b.response_length, + label=_merge_equal_value("label"), + reward=_merge_equal_value("reward"), + loss_mask=a.loss_mask + [0] * obs_len + b.loss_mask, + weight_versions=a.weight_versions + b.weight_versions, + rollout_log_probs=a.rollout_log_probs + [0.0] * obs_len + b.rollout_log_probs, + rollout_routed_experts=b.rollout_routed_experts, + remove_sample=_merge_equal_value("remove_sample"), + status=b.status, + metadata=_merge_equal_value("metadata"), + train_metadata=_merge_equal_value("train_metadata"), + non_generation_time=_merge_equal_value("non_generation_time"), + spec_info=_merge_spec_info(a.spec_info, b.spec_info), + prefix_cache_info=_merge_prefix_cache_info(a.prefix_cache_info, b.prefix_cache_info), + ) + except AssertionError as e: + e.add_note(f"{a=} {b=}") + raise + + +def _merge_spec_info(a: Sample.SpecInfo, b: Sample.SpecInfo) -> Sample.SpecInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.SpecInfo, + spec_accept_token_num=_merge_plus_value("spec_accept_token_num"), + spec_draft_token_num=_merge_plus_value("spec_draft_token_num"), + spec_verify_ct=_merge_plus_value("spec_verify_ct"), + completion_token_num=_merge_plus_value("completion_token_num"), + ) + + +def _merge_prefix_cache_info(a: Sample.PrefixCacheInfo, b: Sample.PrefixCacheInfo) -> Sample.PrefixCacheInfo: + def _merge_plus_value(field): + return getattr(a, field) + getattr(b, field) + + return _create_with_all_fields( + Sample.PrefixCacheInfo, + cached_tokens=_merge_plus_value("cached_tokens"), + total_prompt_tokens=_merge_plus_value("total_prompt_tokens"), + ) + + +def _create_with_all_fields(cls, **kwargs): + expected = {f.name for f in fields(cls)} + actual = set(kwargs.keys()) + assert ( + expected == actual + ), f"{cls.__name__} field mismatch. 
Missing: {expected - actual}, Extra: {actual - expected}" + return cls(**kwargs) + + +def _startswith(*, short, long) -> bool: + if isinstance(short, str) and isinstance(long, str): + return long.startswith(short) + if isinstance(short, list) and isinstance(long, list): + return (len(long) >= len(short)) and (long[: len(short)] == short) + raise NotImplementedError diff --git a/miles/rollout/generate_utils/tokenize_utils.py b/miles/rollout/generate_utils/tokenize_utils.py new file mode 100644 index 000000000..ca4f59938 --- /dev/null +++ b/miles/rollout/generate_utils/tokenize_utils.py @@ -0,0 +1,52 @@ +from typing import Any +from transformers import AutoTokenizer +from miles.rollout.generate_utils.tool_call_utils import tokenize_tool_responses + + +_DUMMY_SYSTEM = {"role": "system", "content": "FOR CALCULATING ADDITIONAL TOKENS ONLY"} +_DUMMY_USER = {"role": "user", "content": "FOR CALCULATING ADDITIONAL TOKENS ONLY"} +_DUMMY_ASSISTANT = {"role": "assistant", "content": "FOR CALCULATING ADDITIONAL TOKENS ONLY"} + + +def calc_generation_prompt_tokens(tokenizer: AutoTokenizer) -> list[int]: + messages = [_DUMMY_SYSTEM, _DUMMY_USER, _DUMMY_ASSISTANT] + with_generation_prompt = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + without_generation_prompt = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False) + assert with_generation_prompt[: len(without_generation_prompt)] == without_generation_prompt + return with_generation_prompt[len(without_generation_prompt) :] + + +# TODO(jiajun): needs an e2e test to validate. Based on https://zhuanlan.zhihu.com/p/1917126584806139373 +# Notice: This function automatically trims think tokens if the model's chat template trims thinking parts (e.g. Qwen3). +def _naive_calc_additional_tokens( + message: dict[str, Any], tokenizer: AutoTokenizer, add_generation_prompt: bool = True +) -> list[int]: + prefix = [_DUMMY_SYSTEM, _DUMMY_USER, _DUMMY_ASSISTANT, _DUMMY_USER] + suffix = [_DUMMY_SYSTEM, _DUMMY_USER] + prefix_tokens = tokenizer.apply_chat_template(prefix, tokenize=True, add_special_tokens=False) + messages_tokens = tokenizer.apply_chat_template( + prefix + [message] + suffix, tokenize=True, add_special_tokens=False + ) + suffix_tokens = tokenizer.apply_chat_template(suffix, tokenize=True, add_special_tokens=False) + + response_tokens = messages_tokens[len(prefix_tokens) : -len(suffix_tokens)] + generation_prompt_tokens = calc_generation_prompt_tokens(tokenizer) + return response_tokens + generation_prompt_tokens + + +# TODO(jiajun): needs an e2e test to validate. +def tokenize_messages( + messages: list[dict[str, Any]], + tokenizer, + add_generation_prompt: bool = True, +) -> list[int]: + token_ids = [] + for message in messages: + if message["role"] == "assistant" or message["role"] == "user" or message["role"] == "system": + token_ids.extend(_naive_calc_additional_tokens(message, tokenizer, add_generation_prompt)) + elif message["role"] == "tool": + token_ids.extend(tokenize_tool_responses([message], tokenizer)) + else: + raise ValueError(f"Unsupported message role: {message['role']}") + + return token_ids diff --git a/miles/rollout/generate_utils/tool_call_utils.py b/miles/rollout/generate_utils/tool_call_utils.py new file mode 100644 index 000000000..85ea87aea --- /dev/null +++ b/miles/rollout/generate_utils/tool_call_utils.py @@ -0,0 +1,115 @@ +""" +Utils to handle tool calls. 
+""" + +import json +import uuid +from collections.abc import Callable +from typing import Any + +from openai.types.chat import ChatCompletionMessageToolCall +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.types import Sample + +_DUMMY_USER = {"role": "user", "content": "dummy"} + + +def create_tool_call_parser(tool_specs, tool_call_parser): + return FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tool_specs), + tool_call_parser=tool_call_parser, + ) + + +async def execute_tool_calls( + tool_calls: list[ToolCallItem | ChatCompletionMessageToolCall], + execute_one: Callable, +) -> list[dict[str, Any]]: + tool_messages = [] + for call in tool_calls: + tool_messages.append(await _execute_tool_call(call, execute_one)) + return tool_messages + + +async def _execute_tool_call( + call: ToolCallItem | ChatCompletionMessageToolCall, execute_one: Callable +) -> dict[str, Any]: + if isinstance(call, ChatCompletionMessageToolCall): + name = call.function.name + params = json.loads(call.function.arguments) if call.function.arguments else {} + tool_call_id = call.id + elif isinstance(call, ToolCallItem): + name = call.name + params = json.loads(call.parameters) if call.parameters else {} + tool_call_id = f"call_{uuid.uuid4().hex[:24]}" + else: + raise TypeError(f"Unsupported tool call type: {type(call)}") + + result = await execute_one(name, params) + assert isinstance(result, str) + + return {"role": "tool", "tool_call_id": tool_call_id, "content": result, "name": name} + + +def update_sample_with_tool_responses(sample: Sample, tool_messages: list[dict[str, Any]], tokenizer): + next_obs_tokens_ids: list[int] = tokenize_tool_responses(tool_messages, tokenizer=tokenizer) + sample.response += tokenizer.decode(next_obs_tokens_ids) + sample.response_length += len(next_obs_tokens_ids) + sample.tokens += next_obs_tokens_ids + sample.loss_mask += [0] * len(next_obs_tokens_ids) + sample.rollout_log_probs += [0.0] * len(next_obs_tokens_ids) + + +# TODO: very naive implementation, need the to-be-implemented e2e test to validate. +def tokenize_tool_responses( + tool_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + return _tokenize_postfix_messages(tool_messages, tokenizer) + + +def _tokenize_postfix_messages( + postfix_messages: list[dict[str, Any]], + tokenizer, +) -> list[int]: + dummy_assistant = _build_dummy_assistant(postfix_messages) + base_messages = [_DUMMY_USER, dummy_assistant] + + messages_without = base_messages + messages_with = base_messages + postfix_messages + + tokens_with = tokenizer.apply_chat_template(messages_with, tokenize=True, add_generation_prompt=True) + tokens_without = tokenizer.apply_chat_template(messages_without, tokenize=True, add_generation_prompt=False) + + assert tokens_with[: len(tokens_without)] == tokens_without, ( + f"Fail to tokenize_tool_responses caused by token prefix mismatch. " + f"This can happen for thinking model or models with special chat template, " + f"and this simple example does not support it yet, " + f"since this means we cannot have a append-only token id list. 
" + f"{tokens_with=} {tokens_without=} " + f"{tokenizer.decode(tokens_with)=} {tokenizer.decode(tokens_without)=} " + ) + return tokens_with[len(tokens_without) :] + + +def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]: + return { + "role": "assistant", + "content": "", + "reasoning_content": " ", + "tool_calls": [ + { + "id": resp.get("tool_call_id", f"call0000{i}"), + "type": "function", + "function": { + "name": resp.get("name", "dummy_func"), + "arguments": {}, + }, + } + for i, resp in enumerate(tool_responses) + ], + } diff --git a/miles/rollout/inference_rollout/__init__.py b/miles/rollout/inference_rollout/__init__.py new file mode 100644 index 000000000..33ccf17bf --- /dev/null +++ b/miles/rollout/inference_rollout/__init__.py @@ -0,0 +1,2 @@ +# This is a refactor of the portions above generate-function in sglang_rollout.py, +# and is give a different name to ensure both code exist at the same time. diff --git a/miles/rollout/inference_rollout/compatibility.py b/miles/rollout/inference_rollout/compatibility.py new file mode 100644 index 000000000..7711e0dd3 --- /dev/null +++ b/miles/rollout/inference_rollout/compatibility.py @@ -0,0 +1,84 @@ +import inspect +from collections.abc import Callable + +from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, + RolloutFnConstructorInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnTrainOutput, +) +from miles.utils.async_utils import run +from miles.utils.misc import load_function + + +class LegacyRolloutFnAdapter: + def __init__(self, input: RolloutFnConstructorInput, fn: Callable): + self.args = input.args + self.data_source = input.data_source + self.fn = fn + + def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + output = self.fn(self.args, input.rollout_id, self.data_source, evaluation=input.evaluation) + + # compatibility for legacy version + if not isinstance(output, (RolloutFnTrainOutput, RolloutFnEvalOutput)): + output = RolloutFnEvalOutput(data=output) if input.evaluation else RolloutFnTrainOutput(samples=output) + + return output + + +def load_rollout_function(input: RolloutFnConstructorInput, path: str): + fn = load_function(path) + + if inspect.isclass(fn): + return fn(input) + else: + return LegacyRolloutFnAdapter(input, fn) + + +def call_rollout_function(fn, input: RolloutFnInput) -> RolloutFnOutput: + output = fn(input) + + if inspect.iscoroutine(output): + output = run(output) + + return output + + +class LegacyGenerateFnAdapter: + def __init__(self, fn: Callable): + self.fn = fn + self._has_evaluation_param = "evaluation" in inspect.signature(fn).parameters + + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + if self._has_evaluation_param: + output = await self.fn(input.args, input.sample, input.sampling_params, evaluation=input.evaluation) + else: + output = await self.fn(input.args, input.sample, input.sampling_params) + + if not isinstance(output, GenerateFnOutput): + output = GenerateFnOutput(samples=output) + + return output + + +def load_generate_function(path: str): + fn = load_function(path) + if fn is None: + return None + + if inspect.isclass(fn): + return fn() + elif _is_legacy_generate_fn(fn): + return LegacyGenerateFnAdapter(fn) + else: + return fn + + +def _is_legacy_generate_fn(fn: Callable) -> bool: + sig = inspect.signature(fn) + params = list(sig.parameters.keys()) + return len(params) >= 3 and params[0] != "input" diff --git a/miles/rollout/inference_rollout/inference_rollout_common.py 
b/miles/rollout/inference_rollout/inference_rollout_common.py new file mode 100644 index 000000000..8518c6e02 --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_common.py @@ -0,0 +1,192 @@ +import asyncio +import logging +from argparse import Namespace +from copy import deepcopy +from typing import Any + +from miles.rollout.base_types import ( + GenerateFnInput, + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnInput, + RolloutFnOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) +from miles.rollout.generate_hub.single_turn import generate +from miles.rollout.inference_rollout.compatibility import load_generate_function +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class GenerateState: + def __init__(self, args: Namespace) -> None: + # persistent state for the generation process + self.args = args + self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + + self.generate_fn_semaphore = asyncio.Semaphore( + args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + ) + self.sampling_params: dict[str, Any] = compute_sampling_params( + args, + temperature=args.rollout_temperature, + top_p=args.rollout_top_p, + top_k=args.rollout_top_k, + max_new_tokens=args.rollout_max_response_len, + ) + + self.generate_function = load_generate_function(args.custom_generate_function_path) or generate + + self.reset() + + def reset(self) -> None: + self.aborted = False + + +async def generate_and_rm( + state: GenerateState, + sample: Sample | list[Sample], + sampling_params: dict[str, Any], + evaluation: bool = False, +) -> Sample | list[Sample]: + args = state.args + + # mask previous off-policy generation for partial rollout + if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: + sample.loss_mask = [0] * sample.response_length + + # For samples with existing response, check if they're complete + if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED: + assert sample.response is not None + if not args.group_rm: + assert sample.reward is not None + return sample + + # generate + async with state.generate_fn_semaphore: + if state.aborted: + sample.status = Sample.Status.ABORTED + return sample + + output = await state.generate_function( + GenerateFnInput( + state=state, + sample=sample, + sampling_params=deepcopy(sampling_params), + evaluation=evaluation, + ) + ) + sample = output.samples + + # TODO change to `if not args.group_rm: do reward model` for more clarity after the refactor below + # for the rm that need the whole group, we will not do the rm here + if args.group_rm: + return sample + + # TODO: unify the two branches into one if we decide to use list as output type + # multi samples + if isinstance(sample, list): + samples = sample + if any([sample.status == Sample.Status.ABORTED for sample in samples]): + return samples + + # for multi agent system, the reward of some sample is calculated during generation. 
+ samples_need_reward = [sample for sample in samples if sample.reward is None] + await batched_async_rm(args, samples_need_reward, inplace_set_reward_field=True) + return samples + else: + if sample.status == Sample.Status.ABORTED: + return sample + # for multi-turn environment, a reward could be assigned to the agent. + if sample.reward is None: + sample.reward = await async_rm(args, sample) + + return sample + + +async def generate_and_rm_group( + state: GenerateState, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False +) -> list[Sample]: + args = state.args + + if state.aborted: + return group + + tasks = [] + for idx, sample in enumerate(group): + current_sampling_params = sampling_params.copy() + if getattr(args, "sglang_enable_deterministic_inference", False): + current_sampling_params["sampling_seed"] = args.rollout_seed + idx + tasks.append( + asyncio.create_task(generate_and_rm(state, sample, current_sampling_params, evaluation=evaluation)) + ) + + group = await asyncio.gather(*tasks) + if state.aborted: + return group + + if args.group_rm: + await batched_async_rm(args, group, inplace_set_reward_field=True) + + return group + + +def compute_sampling_params( + args, + *, + # after unifying configuration, this can be further refactored + temperature, + top_p, + top_k, + max_new_tokens, +): + return dict( + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_new_tokens=max_new_tokens, + stop=args.rollout_stop, + stop_token_ids=args.rollout_stop_token_ids, + skip_special_tokens=args.rollout_skip_special_tokens, + no_stop_trim=True, + spaces_between_special_tokens=False, + ) + + +class InferenceRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + self.data_source = input.data_source + self.state = GenerateState(input.args) + self.eval_prompt_dataset_cache = {} + + async def __call__(self, input: RolloutFnInput) -> RolloutFnOutput: + if input.evaluation: + return await self._call_eval(input) + return await self._call_train(input) + + async def _call_train(self, input: RolloutFnTrainInput) -> RolloutFnTrainOutput: + from miles.rollout.inference_rollout.inference_rollout_train import generate_rollout_async + + output, aborted_samples = await generate_rollout_async( + self.state, input.rollout_id, self.data_source.get_samples + ) + self.data_source.add_samples(aborted_samples) + return output + + async def _call_eval(self, input: RolloutFnEvalInput) -> RolloutFnEvalOutput: + from miles.rollout.inference_rollout.inference_rollout_eval import eval_rollout_single_dataset + + assert not self.state.args.group_rm, "Group RM is not supported for eval rollout" + + coros = [] + for dataset_cfg in getattr(self.state.args, "eval_datasets", []) or []: + coros.append(eval_rollout_single_dataset(self.state, dataset_cfg, self.eval_prompt_dataset_cache)) + results_list = await asyncio.gather(*coros) + results = {k: v for r in results_list for k, v in r.items()} + return RolloutFnEvalOutput(data=results) diff --git a/miles/rollout/inference_rollout/inference_rollout_eval.py b/miles/rollout/inference_rollout/inference_rollout_eval.py new file mode 100644 index 000000000..2d052be0a --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_eval.py @@ -0,0 +1,112 @@ +import asyncio +import copy +import logging +from typing import Any + +from tqdm import tqdm + +from miles.rollout.inference_rollout.inference_rollout_common import ( + GenerateState, + compute_sampling_params, + generate_and_rm, +) +from miles.utils.data import Dataset +from 
miles.utils.eval_config import EvalDatasetConfig +from miles.utils.misc import as_completed_async +from miles.utils.processing_utils import load_processor, load_tokenizer +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +async def eval_rollout_single_dataset( + state: GenerateState, + dataset_cfg: EvalDatasetConfig, + prompt_dataset_cache: dict[Any, Dataset], +) -> dict[str, dict[str, list[Any]]]: + args = state.args + assert not args.group_rm, "Group RM is not supported for eval rollout" + + cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template) + if cache_key not in prompt_dataset_cache: + tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + prompt_dataset_cache[cache_key] = Dataset( + path=dataset_cfg.path, + tokenizer=tokenizer, + processor=processor, + max_length=args.eval_max_prompt_len, + prompt_key=dataset_cfg.input_key, + label_key=dataset_cfg.label_key, + multimodal_keys=args.multimodal_keys, + metadata_key=dataset_cfg.metadata_key, + tool_key=dataset_cfg.tool_key, + apply_chat_template=args.apply_chat_template, + apply_chat_template_kwargs=args.apply_chat_template_kwargs, + ) + dataset = prompt_dataset_cache[cache_key] + + base_sampling_params = compute_sampling_params( + args, + temperature=dataset_cfg.temperature, + top_p=dataset_cfg.top_p, + top_k=dataset_cfg.top_k, + max_new_tokens=dataset_cfg.max_response_len, + ) + + tasks = [] + # do multiple samples for eval prompts + sample_index = 0 + for _i, prompt_sample in enumerate(dataset.samples): + for j in range(dataset_cfg.n_samples_per_eval_prompt): + # use the same prompt for multiple samples + sample = copy.deepcopy(prompt_sample) + sample.index = sample_index + sample_index += 1 + sample.metadata = dataset_cfg.inject_metadata(getattr(sample, "metadata", None)) + sampling_params = base_sampling_params + if getattr(args, "sglang_enable_deterministic_inference", False): + sampling_params = base_sampling_params.copy() + sampling_params["sampling_seed"] = args.rollout_seed + j + tasks.append( + asyncio.create_task( + generate_and_rm( + state, + sample, + sampling_params=sampling_params, + evaluation=True, + ) + ) + ) + + data = [] + do_print = True + pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) + async for sample in as_completed_async(tasks): + if do_print: + # TODO improve this after enhancing samples' type + s = (sample[0] if len(sample) > 0 else None) if isinstance(sample, list) else sample + if s is not None: + logger.info( + "eval_rollout_single_dataset example data: " + f"{[str(s.prompt) + s.response]} " + f"reward={s.reward}" + ) + do_print = False + if isinstance(sample, list): + data.extend(sample) + else: + data.append(sample) + pbar.update(1) + pbar.close() + + data.sort(key=lambda sample: sample.index) + + reward_key = args.eval_reward_key or args.reward_key + return { + dataset_cfg.name: { + "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data], + "truncated": [sample.status == Sample.Status.TRUNCATED for sample in data], + "samples": data, + } + } diff --git a/miles/rollout/inference_rollout/inference_rollout_train.py b/miles/rollout/inference_rollout/inference_rollout_train.py new file mode 100644 index 000000000..bae94ec67 --- /dev/null +++ b/miles/rollout/inference_rollout/inference_rollout_train.py @@ -0,0 +1,146 @@ +import asyncio +import logging +from argparse import Namespace 
+from collections.abc import Callable + +import sglang_router +from packaging.version import parse +from tqdm import tqdm + +from miles.rollout.base_types import RolloutFnTrainOutput +from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter +from miles.rollout.inference_rollout.inference_rollout_common import GenerateState, generate_and_rm_group +from miles.utils.http_utils import get, post +from miles.utils.misc import as_completed_async, load_function +from miles.utils.types import Sample + +logger = logging.getLogger(__name__) + + +async def abort(state: GenerateState, pendings: set, rollout_id: int) -> list[list[Sample]]: + args = state.args + + assert not state.aborted + state.aborted = True + + urls = await get_worker_urls(args) + logger.info(f"Abort request for {urls}") + await asyncio.gather(*[post(f"{url}/abort_request", {"abort_all": True}) for url in urls]) + + # make sure all the pending tasks are finished + aborted_samples = [] + async for group in as_completed_async(pendings): + if not args.partial_rollout: + continue + + # for partial rollout, collect the partial samples into the data buffer + for sample in group: + if sample.response and "start_rollout_id" not in sample.metadata: + sample.metadata["start_rollout_id"] = rollout_id + aborted_samples.append(group) + + if args.partial_rollout: + logger.info(f"Collected {sum(len(x) for x in aborted_samples)} partial samples into the data buffer") + + return aborted_samples + + +async def get_worker_urls(args: Namespace): + if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") + return response["urls"] + else: + response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") + return [worker["url"] for worker in response["workers"]] + + +def submit_generate_tasks(state: GenerateState, samples: list[list[Sample]]): + return [ + asyncio.create_task( + # submit a group of samples as a single task. + generate_and_rm_group( + state, + group, + sampling_params=state.sampling_params.copy(), + evaluation=False, + ) + ) + for group in samples + ] + + +async def generate_rollout_async( + state: GenerateState, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] +) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: + args = state.args + assert args.rollout_global_dataset + + # instantiate data filters + dynamic_filter = load_function(args.dynamic_sampling_filter_path) + + metric_gatherer = MetricGatherer() + + # target_data_size is the total number of valid samples to get + target_data_size = args.rollout_batch_size + + pendings = set() + data = [] + all_data = [] + do_print = True + pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") + while len(data) < target_data_size: + while len(data) + len(pendings) < target_data_size: + # get samples from the buffer and submit the generation requests. 
+ samples = data_source(args.over_sampling_batch_size) + pendings.update(submit_generate_tasks(state, samples)) + + # wait for the generation to finish + done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED) + for task in done: + group: list[Sample] = task.result() + + if do_print: + sample = group[0][0] if isinstance(group[0], list) else group[0] + logger.info( + f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + do_print = False + + assert len(group) == args.n_samples_per_prompt + all_data.append(group) + dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) + if not dynamic_filter_output.keep: + metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) + continue + + # add the samples to the data + # NOTE: here we have not stored all the unused samples back to the data buffer. + if len(data) < target_data_size: + data.append(group) + pbar.update(args.n_samples_per_prompt) + + pbar.close() + sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] + logger.info( + f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}", + ) + + # there are still some unfinished requests, abort them + aborted_samples = await abort(state, pendings, rollout_id) + + assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" + data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) + all_samples = sorted( + all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index + ) + + # reset the global state to prevent effects on the next rollout or eval. + state.reset() + + if f := load_function(args.rollout_sample_filter_path): + f(args, data) + # There can be circumstances where users want to process all samples including filtered ones. + if f := load_function(args.rollout_all_samples_process_path): + f(args, all_samples, data_source) + + return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples diff --git a/miles/rollout/rm_hub/__init__.py b/miles/rollout/rm_hub/__init__.py index 62b253dde..e9ee29db4 100644 --- a/miles/rollout/rm_hub/__init__.py +++ b/miles/rollout/rm_hub/__init__.py @@ -69,8 +69,18 @@ async def async_rm(args, sample: Sample, **kwargs): async def batched_async_rm( args, samples: list[Sample], + inplace_set_reward_field: bool = False, **kwargs, -) -> list[int | float]: +) -> list[int | float] | None: + if inplace_set_reward_field: + rewards = await batched_async_rm(args, samples, **kwargs) + for sample, reward in zip(samples, rewards, strict=True): + assert ( + sample.reward is None + ), f"Overriding sample.reward from {sample.reward} to {reward}, is this intended?" 
+ sample.reward = reward + return None + if args.custom_rm_path is not None: # Ensure the custom reward function is implemented in batch mode rm_function = load_function(args.custom_rm_path) diff --git a/miles/router/middleware_hub/radix_tree.py b/miles/router/middleware_hub/radix_tree.py index 6e722f1e2..67b9d6fe4 100644 --- a/miles/router/middleware_hub/radix_tree.py +++ b/miles/router/middleware_hub/radix_tree.py @@ -584,8 +584,8 @@ def retrieve_from_text(self, text: str, return_logprob: bool = True): text: Input text to get tokens for return_logprob: If True, also return log probabilities Returns: - List of token IDs corresponding to the input text if return_logprob is False. - Tuple of (token_ids, logp) if return_logprob is True. + List of token (IDs, logp, loss_mask) corresponding to the input text + if return_logprob is False, all logp will be 0.0 """ # Call find_longest_prefix to get the match result result = self.find_longest_prefix(text) diff --git a/miles/router/middleware_hub/radix_tree_middleware.py b/miles/router/middleware_hub/radix_tree_middleware.py index db57f6456..b9d62d841 100644 --- a/miles/router/middleware_hub/radix_tree_middleware.py +++ b/miles/router/middleware_hub/radix_tree_middleware.py @@ -66,12 +66,14 @@ def __init__(self, app, *, router): self.router.radix_tree = self.radix_tree async def dispatch(self, request: Request, call_next): - path = request.url.path + if path == "/generate": + return await self._generate(request, call_next) + if path == "/retrieve_from_text": + return await self._retrieve_from_text(request) + return await call_next(request) - if path != "/generate": - return await call_next(request) - + async def _generate(self, request: Request, call_next): request_json = await request.json() if "text" in request_json: input_text = request_json.pop("text", "") @@ -154,6 +156,23 @@ async def dispatch(self, request: Request, call_next): print(f"[miles-router] Warning: Failed to cache trajectory: {e}") return response + async def _retrieve_from_text(self, request: Request): + payload = await request.json() + text = payload.get("text", "") + token_ids, logp, loss_mask = self.radix_tree.retrieve_from_text(text, return_logprob=True) + result = { + "response": text, + "tokens": token_ids, + "loss_mask": loss_mask, + "rollout_logp": logp, + "token_length": len(token_ids), + "loss_mask_length": len(loss_mask), + } + assert ( + len(token_ids) == len(logp) == len(loss_mask) + ), "Token IDs, logp, and loss mask must have the same length" + return JSONResponse(result) + async def postprocess_sample_with_radix_tree(args, sample: Sample, output: dict): assert not args.partial_rollout, "Currently partial rollout is not supported when using miles router" diff --git a/miles/router/router.py b/miles/router/router.py index 2e8ecfc41..f092f359a 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse from starlette.responses import Response +from miles.router.session.sessions import setup_session_routes from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -64,11 +65,12 @@ def __init__(self, args, verbose=False): self.app.add_middleware(middleware, router=self) def _setup_routes(self): - """Setup all the HTTP routes""" + """Setup all the HTTP routes except catch-all proxy""" # sglang-router api self.app.post("/add_worker")(self.add_worker) self.app.get("/list_workers")(self.list_workers) - self.app.post("/retrieve_from_text")(self.retrieve_from_text) + # Session routes - 
must be registered before catch-all + setup_session_routes(self.app, self) # Catch-all route for proxying to SGLang - must be registered LAST self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) @@ -130,39 +132,51 @@ async def _health_check_loop(self): async def proxy(self, request: Request, path: str): """Proxy all other requests to the SGLang router""" - # Forward all other paths to SGLang router + result = await self._do_proxy(request, path) + return self._build_proxy_response(result) + + async def _do_proxy( + self, + request: Request, + path: str, + body: bytes | None = None, + headers: dict | None = None, + ) -> dict: + """Core proxy logic. Returns dict with request_body, response_body, status_code, headers.""" worker_url = self._use_url() url = f"{worker_url}/{path}" - # Get request body and headers - body = await request.body() - headers = dict(request.headers) + if body is None: + body = await request.body() + if headers is None: + headers = dict(request.headers) + if body is not None: + headers = {k: v for k, v in headers.items() if k.lower() not in ("content-length", "transfer-encoding")} try: response = await self.client.request(request.method, url, content=body, headers=headers) - # Eagerly read content so we can return JSON (not streaming) content = await response.aread() - content_type = response.headers.get("content-type", "") - try: - # Prefer parsing JSON if possible - data = json.loads(content) - return JSONResponse( - content=data, - status_code=response.status_code, - headers=dict(response.headers), - ) - except Exception: - # Fall back to raw body with original content type - return Response( - content=content, - status_code=response.status_code, - headers=dict(response.headers), - media_type=content_type or None, - ) - + return { + "request_body": body, + "response_body": content, + "status_code": response.status_code, + "headers": dict(response.headers), + } finally: self._finish_url(worker_url) + def _build_proxy_response(self, result: dict) -> Response: + """Build HTTP response from proxy result.""" + content = result["response_body"] + status_code = result["status_code"] + headers = result["headers"] + content_type = headers.get("content-type", "") + try: + data = json.loads(content) + return JSONResponse(content=data, status_code=status_code, headers=headers) + except Exception: + return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) + async def add_worker(self, request: Request): """Add a new worker to the router. Supports providing the URL via query string or JSON body. 
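# Illustrative sketch (not part of this diff): with this change `/retrieve_from_text`
# is answered by RadixTreeMiddleware (handler shown above) rather than by the router
# itself. A hedged client call, assuming a miles router with that middleware is
# listening on localhost:30000; the asserted fields mirror the handler's response.
import httpx


def _fetch_token_info(text: str, base_url: str = "http://127.0.0.1:30000") -> dict:
    resp = httpx.post(f"{base_url}/retrieve_from_text", json={"text": text}, timeout=30.0)
    resp.raise_for_status()
    info = resp.json()
    assert len(info["tokens"]) == len(info["rollout_logp"]) == len(info["loss_mask"]) == info["token_length"]
    return info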
@@ -197,28 +211,6 @@ async def list_workers(self, request: Request): """List all registered workers""" return {"urls": list(self.worker_request_counts.keys())} - async def retrieve_from_text(self, request: Request): - """Get token information from text input""" - body = await request.body() - payload = json.loads(body) if body else {} - - text = payload.get("text", "") - - # Use radix tree's retrieve_from_text method (no need to fetch weight version here) - token_ids, logp, loss_mask = self.radix_tree.retrieve_from_text(text, return_logprob=True) - - # Handle the result based on whether logp was requested - result = { - "tokens": token_ids, # token IDs - "response": text, # The input text - "loss_mask": loss_mask, # Loss mask for the tokens - "token_length": len(token_ids), - "loss_mask_length": len(loss_mask), - "rollout_logp": logp, - } - - return result - def _use_url(self): """Select worker URL with minimal active requests.""" diff --git a/miles/router/session/seq_trajectory.py b/miles/router/session/seq_trajectory.py new file mode 100644 index 000000000..26c16fc15 --- /dev/null +++ b/miles/router/session/seq_trajectory.py @@ -0,0 +1,263 @@ +import copy +import logging +import uuid +from typing import Any + +from pydantic import BaseModel, Field +from transformers import AutoTokenizer + +from miles.rollout.generate_utils.tokenize_utils import tokenize_messages +from miles.utils.chat_message_utils import calc_last_think_part_index + +logger = logging.getLogger(__name__) + + +class TokenInfo(BaseModel): + tokens: list[str] = Field(default_factory=list) + token_ids: list[int] = Field(default_factory=list) + log_probs: list[float] = Field(default_factory=list) + loss_mask: list[int] = Field(default_factory=list) + + def remove_tokens(self, start_index: int, end_index: int): + # Notice: the end index is exclusive. + self.tokens = self.tokens[start_index:end_index] + self.token_ids = self.token_ids[start_index:end_index] + self.log_probs = self.log_probs[start_index:end_index] + self.loss_mask = self.loss_mask[start_index:end_index] + + def insert_tokens(self, tokens: list[str], token_ids: list[int], log_probs: list[float], loss_mask: list[int]): + self.tokens.extend(tokens) + self.token_ids.extend(token_ids) + self.log_probs.extend(log_probs) + self.loss_mask.extend(loss_mask) + + def append(self, token: str, token_id: int, log_prob: float, loss_mask: int): + self.tokens.append(token) + self.token_ids.append(token_id) + self.log_probs.append(log_prob) + self.loss_mask.append(loss_mask) + + def __add__(self, other: "TokenInfo") -> "TokenInfo": + return TokenInfo( + tokens=self.tokens + other.tokens, + token_ids=self.token_ids + other.token_ids, + log_probs=self.log_probs + other.log_probs, + loss_mask=self.loss_mask + other.loss_mask, + ) + + @staticmethod + def remove_last_assistant_think_and_handle_truncated_message( + token_info: "TokenInfo", model_name: str + ) -> "TokenInfo": + raise NotImplementedError("Not implemented yet.") + tmp = copy.deepcopy(token_info) + start, end = calc_last_think_part_index(tmp.token_ids, model_name) + if start is None: + # No think part found, or think part is truncated, we will not trim. + return tmp + # Notice: after trimming, the old answer tokens cannot be used to calculate loss, so logp and loss mask are set to 0. 
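# Illustrative sketch (not part of this diff): TokenInfo keeps its per-token lists
# aligned; `__add__` concatenates two infos and `remove_tokens(start, end)` keeps
# only the slice [start, end), end exclusive.
from miles.router.session.seq_trajectory import TokenInfo


def _demo_token_info() -> None:
    a = TokenInfo(tokens=["A", "B"], token_ids=[1, 2], log_probs=[-0.1, -0.2], loss_mask=[0, 1])
    b = TokenInfo()
    b.append("C", 3, -0.3, 1)
    merged = a + b
    assert merged.token_ids == [1, 2, 3]
    merged.remove_tokens(0, 2)  # entries at indices 0 and 1 remain
    assert merged.token_ids == [1, 2] and len(merged.loss_mask) == 2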
+ if end is not None: + tmp.remove_tokens(start, end + 1) + if end + 1 < len(token_info.token_ids): + n = len(token_info.token_ids) + tmp.insert_tokens( + token_info.tokens[end + 1 :], + token_info.token_ids[end + 1 :], + [0.0] * (n - end - 1), + [0] * (n - end - 1), + ) + # Handle truncated message. + + return tmp + + +class Turn(BaseModel): + """ + A turn is a multiple message turn, end with an assistant message. + """ + + messages: list[dict[str, Any]] + prompt_tokens: TokenInfo + response_tokens: TokenInfo + + def __init__( + self, + messages: list[dict[str, Any]], + prompt_tokens: TokenInfo, + response_tokens: TokenInfo, + ): + super().__init__( + messages=messages, + prompt_tokens=prompt_tokens, + response_tokens=response_tokens, + ) + assert ( + len(messages) > 0 and messages[-1]["role"] == "assistant" + ), "The last message must be an assistant message." + + def match_prefix_messages_and_return_remaining(self, other: list[dict[str, Any]]) -> list[dict[str, Any]] | None: + """ + If the messages match with other's prefix, return the remaining messages. Otherwise, return None. + """ + if len(self.messages) < len(other): + return None + for i in range(len(other)): + if self.messages[i] != other[i]: + return None + return self.messages[len(other) :] + + def handle_token_out_for_next_turn(self, model_name: str) -> TokenInfo: + raise NotImplementedError("Not implemented yet.") + trimmed_tokens = TokenInfo.remove_last_assistant_think(self.prompt_tokens + self.response_tokens, model_name) + return trimmed_tokens + + +class SeqTrajectory(BaseModel): + """ + Sequence trajectory state. + Can only maintain the token info for the last turn. + It should not have any state. Which means `token_ids` should always include the final chat templated text. + (Note: if seq trajectory has state, when a reqeust crash, bug will happen.) + """ + + num_turns: int = 0 + model_name: str = "" + # History for all turns. 
+ turns: list[Turn] = Field(default_factory=list) + records: list[dict[str, Any]] = Field(default_factory=list) + + def insert_new_turn(self, turn: Turn): + self.turns.append(turn) + self.num_turns += 1 + + def insert_new_record(self, record: dict[str, Any]): + self.records.append(record) + + def match_prefix_turns_and_return_last_turn( + self, messages: list[dict[str, Any]], n: int | None = None + ) -> tuple[Turn, list[dict[str, Any]]]: + if n is None: + n = self.num_turns + assert n > 0, "n must be greater than 0" + remain_messages = messages + for i in range(n): + turn = self.turns[i] + remain_messages = turn.match_prefix_messages_and_return_remaining(remain_messages) + if remain_messages is None: + raise ValueError( + "Under sequence trajectory, messages prefix should match, but unmatched messages: {remain_messages}" + ) + return self.turns[n - 1], remain_messages + + def calc_prompt_tokens_info( + self, + messages: list[dict[str, Any]], + tokenizer: AutoTokenizer, + cross_turn_token_out: bool = False, + inherit_last_assistant: bool = False, + ) -> TokenInfo: + if cross_turn_token_out and self.num_turns > 0: + if inherit_last_assistant: + raise NotImplementedError("Not implemented yet.") + turn, remain_messages = self.match_prefix_turns_and_return_last_turn(messages) + token_info = turn.handle_token_out_for_next_turn(self.model_name) + else: + if self.num_turns >= 2: + turn, remain_messages = self.match_prefix_turns_and_return_last_turn(messages, self.num_turns - 1) + old_token_ids = turn.prompt_tokens.token_ids + turn.response_tokens.token_ids + else: + remain_messages = messages + old_token_ids = [] + new_token_ids = tokenize_messages(remain_messages, tokenizer, add_generation_prompt=True) + token_ids = old_token_ids + new_token_ids + log_probs = [0.0] * len(token_ids) + loss_mask = [0] * len(token_ids) + token_info = TokenInfo( + tokens=tokenizer.convert_ids_to_tokens(token_ids), + token_ids=token_ids, + log_probs=log_probs, + loss_mask=loss_mask, + ) + else: + token_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + log_probs = [0.0] * len(token_ids) + loss_mask = [0] * len(token_ids) + token_info = TokenInfo( + tokens=tokenizer.convert_ids_to_tokens(token_ids), + token_ids=token_ids, + log_probs=log_probs, + loss_mask=loss_mask, + ) + + return token_info + + def get_last_turn_token_info(self) -> TokenInfo: + if not self.turns: + return TokenInfo() + return self.turns[-1].prompt_tokens + self.turns[-1].response_tokens + + +class SeqTrajectoryManager: + def __init__(self, args, tokenizer: AutoTokenizer): + self.sessions: dict[str, SeqTrajectory] = {} + self.args = args + self.tokenizer = tokenizer + + def create_session(self) -> str: + session_id = uuid.uuid4().hex + self.sessions[session_id] = SeqTrajectory() + return session_id + + def get_session_by_id(self, session_id: str) -> TokenInfo | None: + session = self.sessions.get(session_id) + if session is None: + return None + return session.get_last_turn_token_info() + + def get_session_records(self, session_id: str) -> list[dict[str, Any]] | None: + session = self.sessions.get(session_id) + if session is None: + return None + return session.records + + def calc_prompt_tokens(self, session_id: str, messages: list[dict[str, Any]]) -> TokenInfo | None: + # Notice: Sequence trajectory manager will support the prefix of input messages match with the only history. 
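# Illustrative sketch (not part of this diff): the SeqTrajectoryManager lifecycle
# that the HTTP routes in miles/router/session/sessions.py (below) wrap. Assumes a
# Qwen3 tokenizer is available locally; the literal messages are placeholders.
from types import SimpleNamespace

from transformers import AutoTokenizer

from miles.router.session.seq_trajectory import SeqTrajectoryManager, TokenInfo, Turn


def _demo_session_manager() -> None:
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
    manager = SeqTrajectoryManager(SimpleNamespace(), tokenizer)

    session_id = manager.create_session()
    messages = [{"role": "user", "content": "What is 1+7?"}]
    prompt_tokens = manager.calc_prompt_tokens(session_id, messages)  # chat-templated prompt

    # Record the finished turn; sessions.py fills response_tokens from the engine's
    # logprobs, an empty TokenInfo keeps this sketch offline.
    turn = Turn(
        messages=messages + [{"role": "assistant", "content": "\\boxed{8}"}],
        prompt_tokens=prompt_tokens,
        response_tokens=TokenInfo(),
    )
    manager.add_record(session_id, turn)

    assert manager.get_session_by_id(session_id) is not None
    assert manager.delete_session_by_id(session_id)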
+ session = self.sessions.get(session_id) + if session is None: + return None + cross_turn_token_out = getattr(self.args, "cross_turn_token_out", False) + inherit_last_assistant = getattr(self.args, "inherit_last_assistant", False) + token_info: TokenInfo = session.calc_prompt_tokens_info( + messages, + self.tokenizer, + cross_turn_token_out=cross_turn_token_out, + inherit_last_assistant=inherit_last_assistant, + ) + return token_info + # if remain_messages is None: + # TODO(jiajun): Should we truncate think part of the last turn's assistant message, if the new turn does not include any new message? + # Turn 1: sys | user | assistant | tool | assistant + # Turn 2: sys | user | assistant | tool | assistant | ??? + # Noral: sys | user | assistant | tool | assistant | ??? + # Not hard to fix, but temporarily leave this TODO. + # raise ValueError("Currently, we do not support consecutive assistant message input.") + + def delete_session_by_id(self, session_id: str) -> bool: + session = self.sessions.pop(session_id, None) + if session is None: + return False + return True + + def add_record(self, session_id: str, turn: Turn) -> bool: + session = self.sessions.get(session_id) + if session is None: + raise ValueError(f"Session {session_id} not found.") + session.insert_new_turn(turn) + return True + + def add_session_record(self, session_id: str, record: dict[str, Any]) -> bool: + session = self.sessions.get(session_id) + if session is None: + raise ValueError(f"Session {session_id} not found.") + session.insert_new_record(record) + return True diff --git a/miles/router/session/sessions.py b/miles/router/session/sessions.py new file mode 100644 index 000000000..6573127a3 --- /dev/null +++ b/miles/router/session/sessions.py @@ -0,0 +1,115 @@ +import json +import time +from typing import TYPE_CHECKING + +from fastapi import Request +from fastapi.responses import JSONResponse, Response +from pydantic import BaseModel +from transformers import AutoTokenizer + +from miles.router.session.seq_trajectory import SeqTrajectoryManager, TokenInfo, Turn + +if TYPE_CHECKING: + from miles.router.router import MilesRouter + + +class SessionRecord(BaseModel): + timestamp: float + method: str + path: str + request: dict + response: dict + status_code: int + + +class GetSessionResponse(BaseModel): + session_id: str + records: dict + session_records: list[SessionRecord] | None = None + + +def setup_session_routes(app, router: "MilesRouter"): + + # TODO temporary hack before @guapisolo implements TITO + # ============================= HACK START =============================== + # Lazy load tokenizer only when needed (for tests that don't have hf_checkpoint) + tokenizer = AutoTokenizer.from_pretrained(router.args.hf_checkpoint, trust_remote_code=True) + manager = SeqTrajectoryManager(router.args, tokenizer) + + # ============================= HACK END =============================== + + @app.post("/sessions") + async def create_session(): + session_id = manager.create_session() + return {"session_id": session_id} + + @app.get("/sessions/{session_id}") + async def get_session(session_id: str): + token_info = manager.get_session_by_id(session_id) + if token_info is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + session_records = manager.get_session_records(session_id) + return GetSessionResponse( + session_id=session_id, + records=token_info.model_dump(), + session_records=session_records, + ) + + @app.delete("/sessions/{session_id}") + async def delete_session(session_id: str): 
+ status = manager.delete_session_by_id(session_id) + if not status: + return JSONResponse(status_code=404, content={"error": "session not found"}) + return Response(status_code=204) + + @app.api_route("/sessions/{session_id}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) + async def session_proxy(request: Request, session_id: str, path: str): + body = await request.body() + request_body = json.loads(body) if body else {} + + prompt_token_info = TokenInfo() + response_token_info = TokenInfo() + if "messages" in request_body and "input_ids" not in request_body: + prompt_token_info = manager.calc_prompt_tokens(session_id, request_body["messages"]) + if prompt_token_info is None: + return JSONResponse(status_code=404, content={"error": "session not found"}) + token_ids = prompt_token_info.token_ids + request_body["input_ids"] = token_ids + body = json.dumps(request_body).encode("utf-8") + + result = await router._do_proxy(request, path, body=body) + + response = json.loads(result["response_body"]) + + choice = response.get("choices", [{}])[0] + messages = request_body["messages"] + [choice["message"]] + + assert "logprobs" in choice and "content" in choice["logprobs"], "logprobs must be in choice" + logprobs_content = choice["logprobs"]["content"] + + for item in logprobs_content: + if "token" in item and "token_id" not in item: + item["token_id"] = tokenizer.convert_tokens_to_ids(item["token"]) + response_token_info.append(item["token"], item["token_id"], item["logprob"], 1) + + manager.add_record( + session_id, + Turn( + messages=messages, + prompt_tokens=prompt_token_info, + response_tokens=response_token_info, + ), + ) + manager.add_session_record( + session_id, + SessionRecord( + timestamp=time.time(), + method=request.method, + path=path, + request=request_body, + response=response, + status_code=result["status_code"], + ).model_dump(), + ) + + return router._build_proxy_response(result) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 79b2c419c..071020292 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -10,8 +10,10 @@ from miles.backends.sglang_utils.arguments import add_sglang_arguments from miles.backends.sglang_utils.arguments import validate_args as sglang_validate_args +from miles.utils.environ import enable_experimental_rollout_refactor from miles.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from miles.utils.logging_utils import configure_logger +from miles.utils.misc import load_function logger = logging.getLogger(__name__) @@ -204,7 +206,11 @@ def add_rollout_arguments(parser): parser.add_argument( "--rollout-function-path", type=str, - default="miles.rollout.sglang_rollout.generate_rollout", + default=( + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn" + if enable_experimental_rollout_refactor() + else "miles.rollout.sglang_rollout.generate_rollout" + ), help=( "Path to the rollout generation function." 
"You should use this model to create your own custom rollout function, " @@ -1344,6 +1350,20 @@ def add_ci_arguments(parser): ) return parser + def add_user_provided_function_arguments(parser): + args_partial, _ = parser.parse_known_args() + for path in [ + args_partial.rollout_function_path, + args_partial.custom_generate_function_path, + ]: + try: + fn = load_function(path) + except (ModuleNotFoundError, ValueError): + continue + if fn is not None and callable(getattr(fn, "add_arguments", None)): + fn.add_arguments(parser) + return parser + def add_sglang_tp_size(): temp_parser = argparse.ArgumentParser(add_help=False) temp_parser.add_argument("--rollout-num-gpus-per-engine", type=int, default=1) @@ -1374,6 +1394,8 @@ def add_sglang_tp_size(): parser = add_prefill_decode_disaggregation_arguments(parser) parser = add_ci_arguments(parser) parser = add_custom_megatron_plugins_arguments(parser) + if enable_experimental_rollout_refactor(): + parser = add_user_provided_function_arguments(parser) reset_arg( parser, "--custom-config-path", diff --git a/miles/utils/chat_message_utils.py b/miles/utils/chat_message_utils.py new file mode 100644 index 000000000..815b1dca4 --- /dev/null +++ b/miles/utils/chat_message_utils.py @@ -0,0 +1,39 @@ +# These are helper functions for think token lookup. +THINK_TOKEN_START = { + "qwen3": ("", 151667), +} +THINK_TOKEN_END = { + "qwen3": ("", 151668), +} + + +def get_think_token_start(model_name: str) -> tuple[str, int]: + return THINK_TOKEN_START[model_name] + + +def get_think_token_end(model_name: str) -> tuple[str, int]: + return THINK_TOKEN_END[model_name] + + +def calc_last_think_part_index(tokens: list[int], model_name: str) -> tuple[int | None, int | None]: + start_index = None + end_index = None + for i in range(len(tokens)): + if tokens[i] == get_think_token_start(model_name)[1]: + start_index = i + + if start_index is None: + # No think tokens found, no strip. + return None, None + + for i in range(start_index + 1, len(tokens)): + if tokens[i] == get_think_token_end(model_name)[1]: + end_index = i + + # If think part being truncated, end_index would be None. + return start_index, end_index + + +def check_is_truncated_message(tokens: list[int], model_name: str) -> bool: + # TODO: handle this later. 
+ pass diff --git a/miles/utils/environ.py b/miles/utils/environ.py new file mode 100644 index 000000000..35d1f350e --- /dev/null +++ b/miles/utils/environ.py @@ -0,0 +1,14 @@ +import os + +_printed_experimental_rollout_refactor = False + + +def enable_experimental_rollout_refactor() -> bool: + result = bool(int(os.environ.get("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", "0"))) + + global _printed_experimental_rollout_refactor + if result and not _printed_experimental_rollout_refactor: + print("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 is enabled (experimental feature)") + _printed_experimental_rollout_refactor = True + + return result diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 2b3e6e192..0abdbbf59 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -162,11 +162,15 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60): +async def _post(client, url, payload, max_retries=60, action="post"): retry_count = 0 while retry_count < max_retries: try: - response = await client.post(url, json=payload or {}) + if action in ("delete", "get"): + assert not payload + response = await getattr(client, action)(url) + else: + response = await getattr(client, action)(url, json=payload or {}) response.raise_for_status() try: output = response.json() @@ -240,8 +244,8 @@ def __init__(self, concurrency: int): timeout=httpx.Timeout(None), ) - async def do_post(self, url, payload, max_retries=60): - return await _post(self._client, url, payload, max_retries) + async def do_post(self, url, payload, max_retries=60, action="post"): + return await _post(self._client, url, payload, max_retries, action=action) # Create actors per node created = [] @@ -265,7 +269,8 @@ async def do_post(self, url, payload, max_retries=60): _post_actors = created -async def post(url, payload, max_retries=60): +# TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) +async def post(url, payload, max_retries=60, action="post"): # If distributed mode is enabled and actors exist, dispatch via Ray. 
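# Illustrative sketch (not part of this diff): the widened `post` helper below routes
# GET/DELETE through the same retry (and optional Ray) path as POST; the non-POST verbs
# must be called without a payload. Assumes `init_http_client(args)` has already run and
# a miles router (with the session routes from this patch) is reachable at `base_url`.
from miles.utils.http_utils import post


async def _demo_http_actions(base_url: str) -> None:
    created = await post(f"{base_url}/sessions", {})                      # plain POST
    workers = await post(f"{base_url}/list_workers", None, action="get")  # GET with retries
    await post(f"{base_url}/sessions/{created['session_id']}", None, action="delete")
    assert isinstance(workers["urls"], list)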
if _distributed_post_enabled and _post_actors: try: @@ -274,15 +279,16 @@ async def post(url, payload, max_retries=60): actor = _next_actor() if actor is not None: # Use a thread to avoid blocking the event loop on ray.get - obj_ref = actor.do_post.remote(url, payload, max_retries) + obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) return await asyncio.to_thread(ray.get, obj_ref) except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries) + return await _post(_http_client, url, payload, max_retries, action=action) +# TODO unify w/ `post` to add retries and remote-execution async def get(url): response = await _http_client.get(url) response.raise_for_status() diff --git a/miles/utils/misc.py b/miles/utils/misc.py index c0a96d636..bae72ec0d 100644 --- a/miles/utils/misc.py +++ b/miles/utils/misc.py @@ -1,17 +1,55 @@ +import asyncio import importlib import subprocess +from contextlib import contextmanager import ray from miles.utils.http_utils import is_port_available +# Mainly used for test purpose where `load_function` needs to load many in-flight generated functions +class FunctionRegistry: + def __init__(self): + self._registry: dict[str, object] = {} + + @contextmanager + def temporary(self, name: str, fn: object): + self._register(name, fn) + try: + yield + finally: + self._unregister(name) + + def get(self, name: str) -> object | None: + return self._registry.get(name) + + def _register(self, name: str, fn: object) -> None: + assert name not in self._registry + self._registry[name] = fn + + def _unregister(self, name: str) -> None: + assert name in self._registry + self._registry.pop(name) + + +function_registry = FunctionRegistry() + + +# TODO may rename to `load_object` since it can be used to load things like tool_specs def load_function(path): """ - Load a function from a module. + Load a function from registry or module. :param path: The path to the function, e.g. "module.submodule.function". :return: The function object. 
""" + if path is None: + return None + + registered = function_registry.get(path) + if registered is not None: + return registered + module_path, _, attr = path.rpartition(".") module = importlib.import_module(module_path) return getattr(module, attr) @@ -30,8 +68,9 @@ def __call__(cls, *args, **kwargs): cls._instances[cls] = instance return cls._instances[cls] - def clear_instances(cls): - cls._instances = {} + @staticmethod + def clear_all_instances(): + SingletonMeta._instances.clear() def exec_command(cmd: str, capture_output: bool = False) -> str | None: @@ -92,3 +131,8 @@ def should_run_periodic_action( step = rollout_id + 1 return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0) + + +async def as_completed_async(tasks): + for coro in asyncio.as_completed(tasks): + yield await coro diff --git a/miles/utils/test_utils/__init__.py b/miles/utils/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/miles/utils/test_utils/mock_sglang_server.py b/miles/utils/test_utils/mock_sglang_server.py new file mode 100644 index 000000000..7602230d6 --- /dev/null +++ b/miles/utils/test_utils/mock_sglang_server.py @@ -0,0 +1,250 @@ +import asyncio +import re +import time +import uuid +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import asdict, dataclass + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.function_call_parser import FunctionCallParser +from transformers import AutoTokenizer + +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +@dataclass(frozen=True) +class ProcessResultMetaInfo: + weight_version: str | None = None + routed_experts: str | None = None + spec_accept_token_num: int | None = None + spec_draft_token_num: int | None = None + spec_verify_ct: int | None = None + + def to_dict(self) -> dict: + return {k: v for k, v in asdict(self).items() if v is not None} + + +@dataclass(frozen=True) +class ProcessResult: + text: str + finish_reason: str = "stop" + cached_tokens: int = 0 + meta_info: ProcessResultMetaInfo = ProcessResultMetaInfo() + + +ProcessFn = Callable[[str], ProcessResult] + + +class MockSGLangServer: + def __init__( + self, + model_name: str, + process_fn: ProcessFn, + host: str, + port: int, + latency: float = 0.0, + ): + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.process_fn = process_fn + self.host = host + self.port = port or find_available_port(30000) + self.latency = latency + + self.app = FastAPI() + self._server: UvicornThreadServer | None = None + + self.request_log: list[dict] = [] + self._concurrency = Counter() + + self._setup_routes() + + @property + def max_concurrent(self) -> int: + return self._concurrency.max_value + + def reset_stats(self): + self.request_log.clear() + self._concurrency.reset() + + def start(self): + self._server = UvicornThreadServer(self.app, host=self.host, port=self.port) + self._server.start() + + def stop(self): + if self._server is not None: + self._server.stop() + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + def _setup_routes(self): + @self.app.post("/generate") + async def generate(request: Request): + return await self._handle_generate_like_request(request, 
self._compute_generate_response) + + @self.app.post("/v1/chat/completions") + async def chat_completions(request: Request): + return await self._handle_generate_like_request(request, self._compute_chat_completions_response) + + @self.app.get("/health") + async def health(): + return JSONResponse(content={"status": "ok"}) + + @self.app.post("/abort_request") + async def abort_request(_request: Request): + return JSONResponse(content={"status": "ok"}) + + async def _handle_generate_like_request(self, request: Request, compute_fn: Callable[[dict], dict]): + payload = await request.json() + self.request_log.append(payload) + with self._concurrency.track(): + if self.latency > 0: + await asyncio.sleep(self.latency) + response = compute_fn(payload) + return JSONResponse(content=response) + + def _compute_generate_response(self, payload: dict) -> dict: + assert payload.get("return_logprob", True) is True, "MockSGLangServer requires return_logprob=True" + input_ids = payload.get("input_ids", []) + + prompt_str = self.tokenizer.decode(input_ids, skip_special_tokens=False) + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + prompt_tokens = len(input_ids) + completion_tokens = len(output_ids) + + finish_reason_dict = {"type": process_result.finish_reason} + if process_result.finish_reason == "length": + finish_reason_dict["length"] = completion_tokens + + output_token_logprobs = [(-1 / 128 * i, token_id) for i, token_id in enumerate(output_ids)] + + meta_info = { + "finish_reason": finish_reason_dict, + "prompt_tokens": prompt_tokens, + "cached_tokens": process_result.cached_tokens, + "completion_tokens": completion_tokens, + "output_token_logprobs": output_token_logprobs, + **process_result.meta_info.to_dict(), + } + + return {"text": process_result.text, "meta_info": meta_info} + + def _compute_chat_completions_response(self, payload: dict) -> dict: + messages = payload.get("messages", []) + tools = payload.get("tools") + + prompt_str = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, tools=tools + ) + + print(f"_messages: {messages=}", flush=True) + print(f"_compute_chat_completions_response: {prompt_str=}", flush=True) + process_result = self.process_fn(prompt_str) + output_ids = self.tokenizer.encode(process_result.text, add_special_tokens=False) + + logprobs_content = [ + {"token": self.tokenizer.convert_ids_to_tokens(tid), "logprob": -1 / 128 * i} + for i, tid in enumerate(output_ids) + ] + + finish_reason = process_result.finish_reason + tool_calls = None + if tools and finish_reason == "stop": + parser = FunctionCallParser( + tools=TypeAdapter(list[Tool]).validate_python(tools), + tool_call_parser="qwen25", + ) + message_content, parsed_calls = parser.parse_non_stream(process_result.text) + if parsed_calls: + finish_reason = "tool_calls" + tool_calls = [ + { + "id": f"call{i:05d}", + "type": "function", + "function": {"name": call.name, "arguments": call.parameters or "{}"}, + } + for i, call in enumerate(parsed_calls) + ] + else: + message_content = process_result.text + + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message_content, + "tool_calls": tool_calls, + }, + "logprobs": {"content": logprobs_content}, + "finish_reason": finish_reason, + } + ], + } + + +class Counter: + def 
__init__(self): + self._current = 0 + self._max = 0 + + @property + def max_value(self) -> int: + return self._max + + def reset(self): + self._current = 0 + self._max = 0 + + @contextmanager + def track(self): + self._current += 1 + self._max = max(self._max, self._current) + try: + yield + finally: + self._current -= 1 + + +def default_process_fn(prompt: str) -> ProcessResult: + match = re.search(r"What is 1\+(\d+)\?", prompt) + if match: + num = int(match.group(1)) + ans = 1 + num + return ProcessResult(text=f"\\boxed{{{ans}}}", finish_reason="stop") + return ProcessResult(text="I don't understand.", finish_reason="stop") + + +@contextmanager +def with_mock_server( + model_name: str = "Qwen/Qwen3-0.6B", + process_fn: ProcessFn = default_process_fn, + host: str = "127.0.0.1", + port: int | None = None, + latency: float = 0.0, +): + server = MockSGLangServer( + model_name=model_name, + process_fn=process_fn, + host=host, + port=port, + latency=latency, + ) + try: + server.start() + yield server + finally: + server.stop() diff --git a/miles/utils/test_utils/mock_tools.py b/miles/utils/test_utils/mock_tools.py new file mode 100644 index 000000000..26c18738e --- /dev/null +++ b/miles/utils/test_utils/mock_tools.py @@ -0,0 +1,367 @@ +import json +import logging + +from transformers import AutoTokenizer + +from miles.utils.test_utils.mock_sglang_server import ProcessResult + +logger = logging.getLogger(__name__) +SAMPLE_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_year", + "description": "Get current year", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_temperature", + "description": "Get temperature for a location", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, + }, +] + + +def _get_year(params: dict) -> str: + assert len(params) == 0 + return json.dumps({"year": 2026}) + + +def _get_temperature(params: dict) -> str: + temps = {"Mars": -60, "Earth": 15} + location = params.get("location") + assert location in temps, f"Unknown location: {location}" + return json.dumps({"temperature": temps[location]}) + + +TOOL_EXECUTORS = { + "get_year": _get_year, + "get_temperature": _get_temperature, +} + + +async def execute_tool_call(name: str, params: dict) -> str: + return TOOL_EXECUTORS[name](params) + + +_SYSTEM_PROMPT = ( + "<|im_start|>system\n" + "# Tools\n" + "\n" + "You may call one or more functions to assist with the user query.\n" + "\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n" + "\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" +) + + +_TOKENIZER = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) + + +class TwoTurnStub: + """Stub for 2-turn: get_year + get_temperature(Mars) -> final answer""" + + USER_QUESTION = "What is 42 + year + temperature?" 
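# Illustrative sketch (not part of this diff): driving the mock server defined above.
# with_mock_server serves /generate in a background thread, and default_process_fn
# answers "What is 1+N?" prompts deterministically, so "What is 1+7?" yields \boxed{8}.
import httpx
from transformers import AutoTokenizer

from miles.utils.test_utils.mock_sglang_server import with_mock_server


def _demo_mock_server() -> None:
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
    with with_mock_server() as server:
        payload = {
            "input_ids": tokenizer("What is 1+7?", add_special_tokens=False)["input_ids"],
            "return_logprob": True,
        }
        reply = httpx.post(f"{server.url}/generate", json=payload, timeout=30.0).json()
        assert reply["text"] == "\\boxed{8}"
        assert reply["meta_info"]["finish_reason"]["type"] == "stop"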
+ + FIRST_RESPONSE = ( + "Let me get the year and temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_RESPONSE = "The answer is: 42 + 2026 + -60 = 2008." + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and temperature first." + FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + TwoTurnStub.FIRST_PROMPT: TwoTurnStub.FIRST_RESPONSE, + TwoTurnStub.SECOND_PROMPT: TwoTurnStub.SECOND_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") + + +class ThreeTurnStub: + """Stub for 3-turn: get_year + get_temperature(Mars) -> get_temperature(Earth) -> final answer""" + + USER_QUESTION = "What is 42 + year + Mars temperature + Earth temperature?" + + FIRST_RESPONSE = ( + "Let me get the year and Mars temperature first.\n" + "\n" + '{"name": "get_year", "arguments": {}}\n' + "\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Mars"}}\n' + "<|im_end|>\n" + ) + + SECOND_RESPONSE = ( + "Now let me get Earth temperature.\n" + "\n" + '{"name": "get_temperature", "arguments": {"location": "Earth"}}\n' + "<|im_end|>\n" + ) + + FIRST_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": -60}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + SECOND_TOOL_RESPONSE = ( + "<|im_start|>user\n" + "\n" + '{"temperature": 15}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + THIRD_RESPONSE = "The answer is: 42 + 2026 + -60 + 15 = 2023." 
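# Illustrative sketch (not part of this diff): pairing TwoTurnStub (defined above) with
# the mock server. The stub only recognizes the exact prompts it hard-codes, so any drift
# in chat templating or tool-call formatting surfaces as "Unexpected prompt". Assumes the
# Qwen3 chat template renders the first turn to TwoTurnStub.FIRST_PROMPT, which is what
# the stub is built around.
import httpx

from miles.utils.test_utils.mock_sglang_server import with_mock_server
from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub


def _demo_two_turn_stub() -> None:
    with with_mock_server(process_fn=TwoTurnStub.process_fn) as server:
        reply = httpx.post(
            f"{server.url}/v1/chat/completions",
            json={"messages": TwoTurnStub.OPENAI_MESSAGES_FIRST_TURN, "tools": SAMPLE_TOOLS},
            timeout=30.0,
        ).json()
        choice = reply["choices"][0]
        assert choice["finish_reason"] == "tool_calls"
        assert [c["function"]["name"] for c in choice["message"]["tool_calls"]] == [
            "get_year",
            "get_temperature",
        ]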
+ + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + THIRD_PROMPT = SECOND_PROMPT + SECOND_RESPONSE + SECOND_TOOL_RESPONSE + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + THIRD_PROMPT_TOKEN_IDS = _TOKENIZER(THIRD_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = "Let me get the year and Mars temperature first." + FIRST_TOOL_CALLS_OPENAI_FORMAT = [ + {"id": "call00000", "function": {"arguments": "{}", "name": "get_year"}, "type": "function"}, + { + "id": "call00001", + "function": {"arguments": '{"location": "Mars"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + SECOND_RESPONSE_CONTENT = "Now let me get Earth temperature." + SECOND_TOOL_CALLS_OPENAI_FORMAT = [ + { + "id": "call00000", + "function": {"arguments": '{"location": "Earth"}', "name": "get_temperature"}, + "type": "function", + }, + ] + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT = OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT + [ + { + "content": SECOND_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": SECOND_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"temperature": 15}', "name": "get_temperature"}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + ThreeTurnStub.FIRST_PROMPT: ThreeTurnStub.FIRST_RESPONSE, + ThreeTurnStub.SECOND_PROMPT: ThreeTurnStub.SECOND_RESPONSE, + ThreeTurnStub.THIRD_PROMPT: ThreeTurnStub.THIRD_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") + + +class ThinkingThreeTurnStub: + """3-turn stub with a think tag in the assistant response.""" + + USER_QUESTION = ThreeTurnStub.USER_QUESTION + THINK_PREFIX = "\nLet me think.\n\n\n" + FOURTH_USER_MESSAGE = "Thanks." + + FIRST_RESPONSE = THINK_PREFIX + ThreeTurnStub.FIRST_RESPONSE + SECOND_RESPONSE = ThreeTurnStub.SECOND_RESPONSE + THIRD_RESPONSE = ThreeTurnStub.THIRD_RESPONSE + FOURTH_RESPONSE = "You're welcome." 
+ + FIRST_TOOL_RESPONSE = ThreeTurnStub.FIRST_TOOL_RESPONSE + SECOND_TOOL_RESPONSE = ThreeTurnStub.SECOND_TOOL_RESPONSE + + FIRST_PROMPT = _SYSTEM_PROMPT + "<|im_start|>user\n" + USER_QUESTION + "<|im_end|>\n" + "<|im_start|>assistant\n" + SECOND_PROMPT = FIRST_PROMPT + FIRST_RESPONSE + FIRST_TOOL_RESPONSE + THIRD_PROMPT = SECOND_PROMPT + SECOND_RESPONSE + SECOND_TOOL_RESPONSE + FOURTH_PROMPT = ( + THIRD_PROMPT + + THIRD_RESPONSE + + "<|im_end|>\n" + + "<|im_start|>user\n" + + FOURTH_USER_MESSAGE + + "<|im_end|>\n" + + "<|im_start|>assistant\n" + ) + + PROMPT = [{"role": "user", "content": USER_QUESTION}] + + FIRST_PROMPT_TOKEN_IDS = _TOKENIZER(FIRST_PROMPT, add_special_tokens=False)["input_ids"] + SECOND_PROMPT_TOKEN_IDS = _TOKENIZER(SECOND_PROMPT, add_special_tokens=False)["input_ids"] + THIRD_PROMPT_TOKEN_IDS = _TOKENIZER(THIRD_PROMPT, add_special_tokens=False)["input_ids"] + FOURTH_PROMPT_TOKEN_IDS = _TOKENIZER(FOURTH_PROMPT, add_special_tokens=False)["input_ids"] + + FIRST_RESPONSE_CONTENT = THINK_PREFIX + ThreeTurnStub.FIRST_RESPONSE_CONTENT + FIRST_TOOL_CALLS_OPENAI_FORMAT = ThreeTurnStub.FIRST_TOOL_CALLS_OPENAI_FORMAT + SECOND_RESPONSE_CONTENT = ThreeTurnStub.SECOND_RESPONSE_CONTENT + SECOND_TOOL_CALLS_OPENAI_FORMAT = ThreeTurnStub.SECOND_TOOL_CALLS_OPENAI_FORMAT + + OPENAI_MESSAGES_FIRST_TURN = [{"role": "user", "content": USER_QUESTION}] + + OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT = OPENAI_MESSAGES_FIRST_TURN + [ + { + "content": FIRST_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": FIRST_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"year": 2026}', "name": "get_year"}, + {"role": "tool", "tool_call_id": "call00001", "content": '{"temperature": -60}', "name": "get_temperature"}, + ] + + OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT = OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT + [ + { + "content": SECOND_RESPONSE_CONTENT, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": SECOND_TOOL_CALLS_OPENAI_FORMAT, + }, + {"role": "tool", "tool_call_id": "call00000", "content": '{"temperature": 15}', "name": "get_temperature"}, + ] + OPENAI_MESSAGES_FOURTH_TURN_FROM_CLIENT = OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT + [ + { + "content": THIRD_RESPONSE, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": None, + }, + {"role": "user", "content": FOURTH_USER_MESSAGE}, + ] + + @staticmethod + def process_fn(prompt: str) -> ProcessResult: + prompt_response_pairs = { + ThinkingThreeTurnStub.FIRST_PROMPT: ThinkingThreeTurnStub.FIRST_RESPONSE, + ThinkingThreeTurnStub.SECOND_PROMPT: ThinkingThreeTurnStub.SECOND_RESPONSE, + ThinkingThreeTurnStub.THIRD_PROMPT: ThinkingThreeTurnStub.THIRD_RESPONSE, + ThinkingThreeTurnStub.FOURTH_PROMPT: ThinkingThreeTurnStub.FOURTH_RESPONSE, + } + + for expect_prompt, response in prompt_response_pairs.items(): + if prompt == expect_prompt: + return ProcessResult(text=response, finish_reason="stop") + + raise ValueError(f"Unexpected {prompt=}") diff --git a/miles/utils/test_utils/uvicorn_thread_server.py b/miles/utils/test_utils/uvicorn_thread_server.py new file mode 100644 index 000000000..904343c98 --- /dev/null +++ b/miles/utils/test_utils/uvicorn_thread_server.py @@ -0,0 +1,49 @@ +import asyncio +import socket +import threading +import time + +import uvicorn + + +class UvicornThreadServer: + def 
__init__(self, app, host: str, port: int): + self._app = app + self.host = host + self.port = port + self._server: uvicorn.Server | None = None + self._thread: threading.Thread | None = None + + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" + + def start(self) -> None: + config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="info") + self._server = uvicorn.Server(config) + + def run() -> None: + asyncio.run(self._server.serve()) + + self._thread = threading.Thread(target=run, daemon=True) + self._thread.start() + self._wait_for_port_open() + + def stop(self) -> None: + if self._server is not None: + self._server.should_exit = True + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=2.0) + + def _wait_for_port_open(self) -> None: + for _ in range(50): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.connect_ex((self.host, self.port)) + sock.close() + if result == 0: + return + except Exception: + pass + time.sleep(0.1) + raise RuntimeError(f"Failed to start server on {self.url}") diff --git a/miles/utils/types.py b/miles/utils/types.py index 0a2531a7a..5200d625e 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -145,6 +145,24 @@ def get_reward_value(self, args) -> float: def effective_response_length(self): return sum(self.loss_mask) if self.loss_mask is not None else self.response_length + def validate(self): + assert self.response_length >= 0, f"response_length must be >= 0, got {self.response_length}" + assert ( + len(self.tokens) >= self.response_length + ), f"tokens length ({len(self.tokens)}) must be >= response_length ({self.response_length})" + if self.loss_mask is not None: + assert ( + len(self.loss_mask) == self.response_length + ), f"loss_mask length ({len(self.loss_mask)}) != response_length ({self.response_length})" + if self.rollout_log_probs is not None: + assert ( + len(self.rollout_log_probs) == self.response_length + ), f"rollout_log_probs length ({len(self.rollout_log_probs)}) != response_length ({self.response_length})" + if self.rollout_routed_experts is not None: + actual = len(self.rollout_routed_experts) + expect = len(self.tokens) - 1 + assert actual == expect, f"rollout_routed_experts length ({actual}) != len(tokens) - 1 ({expect})" + def update_from_meta_info(self, args, meta_info: dict): """ Update the sample with new information from meta_info returned by the rollout engine. 
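# Illustrative sketch (not part of this diff): what the new Sample.validate() enforces.
# The constructor call mirrors the fields the test fixtures below use; treating the
# remaining fields as defaulted is an assumption, and the token/label values are
# placeholders.
from miles.utils.types import Sample


def _demo_sample_validate() -> None:
    sample = Sample(prompt="What is 1+7?", tokens=[1, 2, 3, 4], response="\\boxed{8}", response_length=2)
    sample.loss_mask = [1, 1]
    sample.rollout_log_probs = [-0.1, -0.2]
    sample.validate()  # ok: len(tokens) >= response_length, per-token lists match it

    sample.loss_mask = [1]  # wrong length
    try:
        sample.validate()
        raise RuntimeError("expected the loss_mask length check to fail")
    except AssertionError:
        pass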
diff --git a/requirements.txt b/requirements.txt index 2c20195fc..dacd51132 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf pillow +pybase64 pylatexenc pyyaml qwen_vl_utils # for VLM diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/ci/gpu_lock_exec.py b/tests/ci/gpu_lock_exec.py index 9507e2e85..20379f76a 100644 --- a/tests/ci/gpu_lock_exec.py +++ b/tests/ci/gpu_lock_exec.py @@ -19,11 +19,14 @@ def main(): _execute_print_only(args) return - fd_locks = _try_acquire(args) + if args.count == 0 and not args.devices: + print("[gpu_lock_exec] Do not acquire GPU since count=0", flush=True) + else: + fd_locks = _try_acquire(args) - dev_list = ",".join(str(x.gpu_id) for x in fd_locks) - os.environ[args.target_env_name] = dev_list - print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) + dev_list = ",".join(str(x.gpu_id) for x in fd_locks) + os.environ[args.target_env_name] = dev_list + print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) _os_execvp(args) diff --git a/tests/e2e/.gitkeep b/tests/e2e/.gitkeep new file mode 100644 index 000000000..615f2b076 --- /dev/null +++ b/tests/e2e/.gitkeep @@ -0,0 +1 @@ +# TODO: may move e2e tests to this folder \ No newline at end of file diff --git a/tests/fast/__init__.py b/tests/fast/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/conftest.py b/tests/fast/conftest.py new file mode 100644 index 000000000..4cb30e91f --- /dev/null +++ b/tests/fast/conftest.py @@ -0,0 +1,15 @@ +import os + +import pytest + +from tests.fast.fixtures.generation_fixtures import generation_env +from tests.fast.fixtures.rollout_fixtures import rollout_env + +_ = rollout_env, generation_env + + +@pytest.fixture(autouse=True) +def enable_experimental_rollout_refactor(): + os.environ["MILES_EXPERIMENTAL_ROLLOUT_REFACTOR"] = "1" + yield + os.environ.pop("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", None) diff --git a/tests/fast/fixtures/__init__.py b/tests/fast/fixtures/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/fast/fixtures/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/fast/fixtures/generation_fixtures.py b/tests/fast/fixtures/generation_fixtures.py new file mode 100644 index 000000000..816371ee3 --- /dev/null +++ b/tests/fast/fixtures/generation_fixtures.py @@ -0,0 +1,274 @@ +""" +Fixtures to test custom-generate-function +""" + +from argparse import Namespace +from contextlib import contextmanager +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any +from unittest.mock import patch + +import pytest +import requests + +from miles.rollout.base_types import GenerateFnInput +from miles.rollout.inference_rollout.compatibility import load_generate_function +from miles.rollout.inference_rollout.inference_rollout_common import GenerateState +from miles.router.router import MilesRouter +from miles.utils.async_utils import run +from miles.utils.http_utils import find_available_port, init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer +from miles.utils.types import Sample + +MODEL_NAME = "Qwen/Qwen3-0.6B" 
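# Illustrative sketch (not part of this diff): the autouse conftest fixture above sets
# MILES_EXPERIMENTAL_ROLLOUT_REFACTOR=1 for every fast test, which flips
# enable_experimental_rollout_refactor() and thereby the default --rollout-function-path
# (see the arguments.py hunk earlier in this diff).
import os

from miles.utils.environ import enable_experimental_rollout_refactor


def _demo_refactor_flag() -> None:
    os.environ["MILES_EXPERIMENTAL_ROLLOUT_REFACTOR"] = "1"
    assert enable_experimental_rollout_refactor() is True
    os.environ.pop("MILES_EXPERIMENTAL_ROLLOUT_REFACTOR", None)
    assert enable_experimental_rollout_refactor() is False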
+RESPONSE_TEXT = "\\boxed{8}" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} + +VARIANT_TO_GENERATE_FN_PATH = { + "old_sglang_rollout": "miles.rollout.sglang_rollout.generate", + "single_turn": "miles.rollout.generate_hub.single_turn.generate", + "multi_turn_single_sample": "miles.rollout.generate_hub.multi_turn.generate", + "multi_turn_multi_samples": "miles.rollout.generate_hub.multi_turn.generate", + "agentic_tool_call_single_sample": "miles.rollout.generate_hub.agentic_tool_call.generate", + "agentic_tool_call_multi_samples": "miles.rollout.generate_hub.agentic_tool_call.generate", +} + + +def extra_argv_for_variant( + variant: str, + *, + custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", +) -> list[str]: + argv = [ + "--custom-generate-function-path", + custom_generate_function_path or VARIANT_TO_GENERATE_FN_PATH[variant], + ] + + if variant in ( + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ): + argv += [ + "--generate-max-turns", + str(generate_max_turns), + "--generate-tool-specs-path", + generate_tool_specs_path, + "--generate-execute-tool-function-path", + generate_execute_tool_function_path, + ] + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + argv += ["--generate-tool-call-parser", generate_tool_call_parser] + if variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples"): + argv.append("--generate-multi-samples") + + return argv + + +def listify(x): + return x if isinstance(x, list) else [x] + + +def make_sample( + *, + prompt: str | list[dict] = "What is 1+7?", + tokens: list[int] | None = None, + response: str = "", + response_length: int = 0, + status: Sample.Status = Sample.Status.PENDING, + multimodal_inputs: dict | None = None, +) -> Sample: + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +@dataclass +class GenerateEnv: + args: Namespace + mock_server: Any + + +@dataclass +class GenerateResult: + sample: Sample | list[Sample] + requests: list[dict] + + +def run_generate( + env: GenerateEnv, + sample: Sample, + sampling_params: dict[str, Any] | None = None, + *, + variant: str = "single_turn", +) -> GenerateResult: + env.mock_server.request_log.clear() + result_sample = run( + _call_generate( + env.args, + sample, + sampling_params or DEFAULT_SAMPLING_PARAMS, + variant=variant, + ) + ) + return GenerateResult(sample=result_sample, requests=list(env.mock_server.request_log)) + + +async def _call_generate( + args: Namespace, + sample: Sample, + sampling_params: dict[str, Any], + *, + variant: str = "single_turn", +) -> Sample: + generate_fn = load_generate_function(VARIANT_TO_GENERATE_FN_PATH[variant]) + state = GenerateState(args) + input = GenerateFnInput(state=state, sample=sample, sampling_params=sampling_params.copy(), evaluation=False) + output = await generate_fn(input) + return output.samples + + +def make_args( + *, + variant: str, + router_port: int, + use_rollout_routing_replay: bool = False, + sglang_speculative_algorithm: str | None = None, + model_name: str = MODEL_NAME, + extra_argv: list[str] | None = None, + 
custom_generate_function_path: str | None = None, + generate_max_turns: int = 16, + generate_tool_specs_path: str = "miles.utils.test_utils.mock_tools.SAMPLE_TOOLS", + generate_tool_call_parser: str = "qwen25", + generate_execute_tool_function_path: str = "miles.utils.test_utils.mock_tools.execute_tool_call", + rollout_max_context_len: int | None = None, +) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + model_name, + "--prompt-data", + "/dev/null", + "--rm-type", + "math", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + if use_rollout_routing_replay: + argv.append("--use-rollout-routing-replay") + if sglang_speculative_algorithm: + argv.extend(["--sglang-speculative-algorithm", sglang_speculative_algorithm]) + if rollout_max_context_len is not None: + argv.extend(["--rollout-max-context-len", str(rollout_max_context_len)]) + + argv.extend( + extra_argv_for_variant( + variant, + custom_generate_function_path=custom_generate_function_path, + generate_max_turns=generate_max_turns, + generate_tool_specs_path=generate_tool_specs_path, + generate_tool_call_parser=generate_tool_call_parser, + generate_execute_tool_function_path=generate_execute_tool_function_path, + ) + ) + + if extra_argv: + argv.extend(extra_argv) + + from miles.utils.arguments import parse_args + + with patch("sys.argv", argv): + args = parse_args() + + init_http_client(args) + return args + + +@contextmanager +def with_miles_router(backend_url: str, model_name: str): + router_args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + hf_checkpoint=model_name, + ) + router = MilesRouter(router_args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend_url}) + + try: + yield port + finally: + server.stop() + + +@pytest.fixture +def generation_env(request, variant): + SingletonMeta.clear_all_instances() + params = getattr(request, "param", {}) + args_kwargs = params.get("args_kwargs", {}) + model_name = args_kwargs.get("model_name", MODEL_NAME) + custom_generate_function_path = VARIANT_TO_GENERATE_FN_PATH[variant] + + def process_fn(_): + x = params.get("process_fn_kwargs", {}) + return ProcessResult( + text=x.get("response_text", RESPONSE_TEXT), + finish_reason=x.get("finish_reason", "stop"), + cached_tokens=x.get("cached_tokens", 0), + meta_info=ProcessResultMetaInfo( + weight_version=x.get("weight_version"), + routed_experts=x.get("routed_experts"), + spec_accept_token_num=x.get("spec_accept_token_num"), + spec_draft_token_num=x.get("spec_draft_token_num"), + spec_verify_ct=x.get("spec_verify_ct"), + ), + ) + + with with_mock_server(model_name=model_name, process_fn=process_fn) as mock_server: + with with_miles_router(mock_server.url, model_name) as router_port: + other_args_kwargs = {k: v for k, v in args_kwargs.items() if k != "model_name"} + args = make_args( + variant=variant, + router_port=router_port, + model_name=model_name, + custom_generate_function_path=custom_generate_function_path, + **other_args_kwargs, + ) + yield GenerateEnv(args=args, 
mock_server=mock_server) + + SingletonMeta.clear_all_instances() diff --git a/tests/fast/fixtures/rollout_fixtures.py b/tests/fast/fixtures/rollout_fixtures.py new file mode 100644 index 000000000..44d8a50d7 --- /dev/null +++ b/tests/fast/fixtures/rollout_fixtures.py @@ -0,0 +1,127 @@ +""" +Fixtures to test rollout-function +""" + +import json +from argparse import Namespace +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from unittest.mock import patch + +import pytest +import requests + +from miles.rollout.data_source import DataSource, RolloutDataSourceWithBuffer +from miles.router.router import MilesRouter +from miles.utils.arguments import parse_args +from miles.utils.http_utils import find_available_port, init_http_client +from miles.utils.misc import SingletonMeta +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +@dataclass(frozen=True) +class RolloutEnvConfig: + extra_argv: list[str] | None = None + data_rows: list[dict] | None = None + latency: float = 0.0 + + +@dataclass(frozen=True) +class RolloutEnv: + args: Namespace + data_source: DataSource + mock_server: MockSGLangServer + + +def _build_args(*, data_path: str, router_port: int, extra_argv: list[str] | None = None) -> Namespace: + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "1", + "--rollout-num-gpus-per-engine", + "1", + "--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + data_path, + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--eval-prompt-data", + "toy", + data_path, + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + str(router_port), + "--rollout-max-response-len", + "16", + ] + (extra_argv or []) + with patch("sys.argv", argv): + args = parse_args() + args.miles_router_middleware_paths = [] + init_http_client(args) + return args + + +@contextmanager +def _with_miles_router(args: Namespace) -> Iterator[UvicornThreadServer]: + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + try: + server.start() + yield server + finally: + server.stop() + + +def _write_jsonl(path: str, rows: list[dict]) -> None: + Path(path).write_text("".join(json.dumps(row, ensure_ascii=False) + "\n" for row in rows), encoding="utf-8") + + +DEFAULT_DATA_ROWS = [{"input": "What is 1+7?", "label": "8"}] + + +@pytest.fixture +def rollout_env(tmp_path, request) -> RolloutEnv: + config = request.param + assert isinstance(config, RolloutEnvConfig) + + data_rows = config.data_rows or DEFAULT_DATA_ROWS + + data_path = str(tmp_path / "data.jsonl") + _write_jsonl(data_path, data_rows) + + router_port = find_available_port(20000) + args = _build_args(data_path=data_path, router_port=router_port, extra_argv=config.extra_argv) + + SingletonMeta.clear_all_instances() + + with with_mock_server(model_name=args.hf_checkpoint, latency=config.latency) as mock_server: + with _with_miles_router(args) as router_server: + r = requests.post( + f"{router_server.url}/add_worker", + params={"url": mock_server.url}, + timeout=5.0, + ) + r.raise_for_status() + + data_source = RolloutDataSourceWithBuffer(args) + yield RolloutEnv(args=args, data_source=data_source, 
mock_server=mock_server) + + SingletonMeta.clear_all_instances() diff --git a/tests/fast/rollout/__init__.py b/tests/fast/rollout/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/generate_hub/__init__.py b/tests/fast/rollout/generate_hub/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/generate_hub/test_multi_turn.py b/tests/fast/rollout/generate_hub/test_multi_turn.py new file mode 100644 index 000000000..c3ef3e855 --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_multi_turn.py @@ -0,0 +1,656 @@ +from copy import deepcopy +from dataclasses import dataclass, replace +from itertools import groupby + +import numpy as np +import pybase64 +import pytest +from tests.fast.fixtures.generation_fixtures import GenerateEnv, generation_env, listify, make_sample, run_generate +from transformers import AutoTokenizer + +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, ThinkingThreeTurnStub, ThreeTurnStub, TwoTurnStub +from miles.utils.types import Sample + +_ = generation_env, SAMPLE_TOOLS, TwoTurnStub, ThreeTurnStub + + +def is_agentic_variant(variant: str) -> bool: + return variant in ("agentic_tool_call_single_sample", "agentic_tool_call_multi_samples") + + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +DEFAULT_SAMPLING_PARAMS = {"max_new_tokens": 64, "temperature": 0.7} +TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + + +@pytest.fixture( + params=[ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", + ] +) +def variant(request): + return request.param + + +@dataclass(frozen=True) +class SampleParsedChunk: + tokens_decoded_str: str + loss_mask_value: int + rollout_log_probs: list[float] + + +@dataclass +class ExpectedSampleInfo: + chunks: list[SampleParsedChunk] + partial_sample: Sample + + +def token_len(text: str) -> int: + return len(TOKENIZER(text, add_special_tokens=False)["input_ids"]) + + +def expected_chunk(text: str, loss_mask: int) -> SampleParsedChunk: + n = token_len(text) + log_probs = [-1 / 128 * i for i in range(n)] if loss_mask else [0.0] * n + return SampleParsedChunk(text, loss_mask, log_probs) + + +def parse_sample_into_chunks(sample: Sample, tokenizer) -> list[SampleParsedChunk]: + prompt_len = len(sample.tokens) - sample.response_length + response_tokens = sample.tokens[prompt_len:] + loss_mask = sample.loss_mask or [] + log_probs = sample.rollout_log_probs or [] + + chunks = [] + idx = 0 + for mask_val, group in groupby(loss_mask): + group_len = len(list(group)) + sli = slice(idx, idx + group_len) + chunks.append( + SampleParsedChunk( + tokens_decoded_str=tokenizer.decode(response_tokens[sli]), + loss_mask_value=mask_val, + rollout_log_probs=log_probs[sli], + ) + ) + idx += group_len + return chunks + + +def expected_partial_sample( + *, + prompt: list[dict], + response: str, + response_length: int, + status: Sample.Status = Sample.Status.COMPLETED, +) -> Sample: + return Sample( + prompt=prompt, + response=response, + response_length=response_length, + status=status, + tokens=[], + loss_mask=[], + rollout_log_probs=[], + weight_versions=[], + spec_info=Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + + +def verify_samples(actual: Sample | list[Sample], expected: 
list[ExpectedSampleInfo]): + actual = listify(actual) + assert len(actual) == len(expected) + + for actual_item, expected_item in zip(actual, expected, strict=True): + actual_chunks = parse_sample_into_chunks(actual_item, TOKENIZER) + assert actual_chunks == expected_item.chunks + + actual_partial = replace( + deepcopy(actual_item), + tokens=[], + loss_mask=[], + rollout_log_probs=[], + prefix_cache_info=Sample.PrefixCacheInfo(), + ) + assert actual_partial == expected_item.partial_sample + + +def _run_generate(variant: str, env: GenerateEnv, sample: Sample, sampling_params: dict | None = None): + return run_generate(env, sample, sampling_params, variant=variant) + + +def expected_request(input_ids: list[int], sampling_params: dict | None = None) -> dict: + return { + "input_ids": input_ids, + "sampling_params": sampling_params or DEFAULT_SAMPLING_PARAMS, + "return_logprob": True, + "return_routed_experts": False, + } + + +def expected_openai_request(messages: list[dict]) -> dict: + input_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + return {"messages": messages, "model": "default", "tools": SAMPLE_TOOLS, "input_ids": input_ids} + + +SINGLE_TURN_PROMPT = [{"role": "user", "content": "What is 1+1?"}] +SINGLE_TURN_RESPONSE = "The answer is 2." +_SINGLE_TURN_PROMPT_TEXT = TOKENIZER.apply_chat_template( + SINGLE_TURN_PROMPT, tokenize=False, add_generation_prompt=True, tools=SAMPLE_TOOLS +) +SINGLE_TURN_PROMPT_TOKEN_IDS = TOKENIZER(_SINGLE_TURN_PROMPT_TEXT, add_special_tokens=False)["input_ids"] +SINGLE_TURN_PROMPT_TOKEN_LEN = len(SINGLE_TURN_PROMPT_TOKEN_IDS) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicMultiTurn: + def test_single_turn_no_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="stop" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(SINGLE_TURN_PROMPT)] + else: + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response=SINGLE_TURN_RESPONSE, response_length=6 + ), + ), + ], + ) + + def test_two_turns_with_tool_call(self, variant, generation_env): + generation_env.mock_server.process_fn = TwoTurnStub.process_fn + + S = TwoTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + S.SECOND_RESPONSE + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + 
response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestExitConditions: + def test_partial_rollout_not_supported(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("agentic_tool_call does not check partial_rollout flag") + generation_env.args.partial_rollout = True + + with pytest.raises(AssertionError, match="Partial rollout is not supported"): + _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + def test_abort_preserves_content(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("agentic_tool_call does not handle abort finish_reason") + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=SINGLE_TURN_RESPONSE, finish_reason="abort" + ) + + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + + assert result.requests == [expected_request(SINGLE_TURN_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[ + SampleParsedChunk( + tokens_decoded_str=SINGLE_TURN_RESPONSE, + loss_mask_value=1, + rollout_log_probs=[-1 / 128 * i for i in range(6)], + ), + ], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, + response=SINGLE_TURN_RESPONSE, + response_length=6, + status=Sample.Status.ABORTED, + ), + ), + ], + ) + + def test_finish_reason_length_exits_and_preserves_content(self, variant, generation_env): + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="length") + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + verify_samples( + result.sample, + [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + status=Sample.Status.TRUNCATED, + ), + ), + ], + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"generate_max_turns": 1}}], indirect=True) + def test_max_turns_reached(self, variant, generation_env): + S = TwoTurnStub + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=S.FIRST_RESPONSE, finish_reason="stop") + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN)] + else: + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE, + 
response_length=token_len(S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestRespectMaxContextLen: + @pytest.mark.parametrize( + "generation_env", [{"args_kwargs": {"rollout_max_context_len": SINGLE_TURN_PROMPT_TOKEN_LEN}}], indirect=True + ) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + result = _run_generate(variant, generation_env, make_sample(prompt=SINGLE_TURN_PROMPT)) + assert result.requests == [] + if variant == "multi_turn_single_sample": + expected = [ + ExpectedSampleInfo( + chunks=[], + partial_sample=expected_partial_sample( + prompt=SINGLE_TURN_PROMPT, response="", response_length=0, status=Sample.Status.TRUNCATED + ), + ) + ] + else: + expected = [] + verify_samples(result.sample, expected) + + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": { + "rollout_max_context_len": len(TwoTurnStub.FIRST_PROMPT_TOKEN_IDS) + + token_len(TwoTurnStub.FIRST_RESPONSE) + + token_len(TwoTurnStub.FIRST_TOOL_RESPONSE) + } + } + ], + indirect=True, + ) + def test_second_turn_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + assert result.requests == [expected_request(S.FIRST_PROMPT_TOKEN_IDS)] + if variant == "multi_turn_single_sample": + partial_response = S.FIRST_RESPONSE + S.FIRST_TOOL_RESPONSE + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=partial_response, + response_length=token_len(partial_response), + status=Sample.Status.TRUNCATED, + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + status=Sample.Status.TRUNCATED, + ), + ), + ] + verify_samples(result.sample, expected) + + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 10}}, + 10, + ), + ( + {"args_kwargs": {"rollout_max_context_len": len(TwoTurnStub.SECOND_PROMPT_TOKEN_IDS) + 100}}, + 64, + ), + ], + indirect=["generation_env"], + ) + def test_second_turn_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + S = TwoTurnStub + generation_env.mock_server.process_fn = S.process_fn + + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + assert len(result.requests) >= 2 + assert result.requests[1]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[1]["sampling_params"]["temperature"] == DEFAULT_SAMPLING_PARAMS["temperature"] + + +class TestThreeTurn: + """Need to test 3-turn case besides 2-turn, because e.g. 
merge_samples may behave differently.""" + + def test_three_turns_with_sequential_tool_calls(self, variant, generation_env): + generation_env.mock_server.process_fn = ThreeTurnStub.process_fn + + S = ThreeTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + expected_openai_request(S.OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + expected_request(S.THIRD_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = ( + S.FIRST_RESPONSE + + S.FIRST_TOOL_RESPONSE + + S.SECOND_RESPONSE + + S.SECOND_TOOL_RESPONSE + + S.THIRD_RESPONSE + ) + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + expected_chunk(S.SECOND_TOOL_RESPONSE, 0), + expected_chunk(S.THIRD_RESPONSE, 1), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.THIRD_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.THIRD_RESPONSE, + response_length=token_len(S.THIRD_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + +class TestFourTurnWithThink: + def test_four_turns_with_think_prefix(self, variant, generation_env): + generation_env.mock_server.process_fn = ThinkingThreeTurnStub.process_fn + + S = ThinkingThreeTurnStub + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT)) + + if is_agentic_variant(variant): + assert result.requests == [ + expected_openai_request(S.OPENAI_MESSAGES_FIRST_TURN), + expected_openai_request(S.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT), + expected_openai_request(S.OPENAI_MESSAGES_THIRD_TURN_FROM_CLIENT), + ] + else: + assert result.requests == [ + expected_request(S.FIRST_PROMPT_TOKEN_IDS), + expected_request(S.SECOND_PROMPT_TOKEN_IDS), + expected_request(S.THIRD_PROMPT_TOKEN_IDS), + ] + if variant in ("multi_turn_single_sample", "agentic_tool_call_single_sample"): + full_response = ( + S.FIRST_RESPONSE + + S.FIRST_TOOL_RESPONSE + + S.SECOND_RESPONSE + + S.SECOND_TOOL_RESPONSE + + S.THIRD_RESPONSE + ) + expected = [ + ExpectedSampleInfo( + chunks=[ + expected_chunk(S.FIRST_RESPONSE, 1), + expected_chunk(S.FIRST_TOOL_RESPONSE, 0), + expected_chunk(S.SECOND_RESPONSE, 1), + expected_chunk(S.SECOND_TOOL_RESPONSE, 0), + expected_chunk(S.THIRD_RESPONSE, 1), + ], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=full_response, + response_length=token_len(full_response), + ), + ), + ] + else: + expected = [ + ExpectedSampleInfo( + chunks=[expected_chunk(S.FIRST_RESPONSE, 1)], + 
partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.FIRST_RESPONSE, + response_length=token_len(S.FIRST_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.SECOND_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.SECOND_RESPONSE, + response_length=token_len(S.SECOND_RESPONSE), + ), + ), + ExpectedSampleInfo( + chunks=[expected_chunk(S.THIRD_RESPONSE, 1)], + partial_sample=expected_partial_sample( + prompt=S.PROMPT, + response=S.THIRD_RESPONSE, + response_length=token_len(S.THIRD_RESPONSE), + ), + ), + ] + verify_samples(result.sample, expected) + + messages_without_think = deepcopy(S.OPENAI_MESSAGES_FOURTH_TURN_FROM_CLIENT) + messages_without_think[1]["content"] = messages_without_think[1]["content"].replace(S.THINK_PREFIX, "") + token_ids_with_think = TOKENIZER.apply_chat_template( + S.OPENAI_MESSAGES_FOURTH_TURN_FROM_CLIENT, tokenize=True, add_generation_prompt=True + ) + token_ids_without_think = TOKENIZER.apply_chat_template( + messages_without_think, tokenize=True, add_generation_prompt=True + ) + assert token_ids_with_think == token_ids_without_think + + +class TestRoutedExpertsMultiTurn: + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": { + "use_rollout_routing_replay": True, + } + } + ], + indirect=True, + ) + def test_two_turns_routed_experts(self, variant, generation_env): + if is_agentic_variant(variant): + pytest.skip("TODO: implement") + + S = TwoTurnStub + num_layers, moe_router_topk = 2, 4 + generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk + + def make_routed_experts(prompt_token_ids, response_text): + total_tokens = len(prompt_token_ids) + token_len(response_text) + routed_experts_len = total_tokens - 1 + return np.arange(routed_experts_len * num_layers * moe_router_topk, dtype=np.int32).reshape( + routed_experts_len, num_layers, moe_router_topk + ) + + first_routed_experts = make_routed_experts(S.FIRST_PROMPT_TOKEN_IDS, S.FIRST_RESPONSE) + second_routed_experts = make_routed_experts(S.SECOND_PROMPT_TOKEN_IDS, S.SECOND_RESPONSE) + + def process_fn(prompt: str) -> ProcessResult: + if prompt == S.FIRST_PROMPT: + text, routed_experts = S.FIRST_RESPONSE, first_routed_experts + elif prompt == S.SECOND_PROMPT: + text, routed_experts = S.SECOND_RESPONSE, second_routed_experts + else: + raise ValueError(f"Unexpected prompt: {prompt}") + return ProcessResult( + text=text, + finish_reason="stop", + meta_info=ProcessResultMetaInfo( + routed_experts=pybase64.b64encode(routed_experts.tobytes()).decode("ascii") + ), + ) + + generation_env.mock_server.process_fn = process_fn + result = _run_generate(variant, generation_env, make_sample(prompt=S.PROMPT), DEFAULT_SAMPLING_PARAMS) + + sample = result.sample[-1] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == second_routed_experts.shape + np.testing.assert_array_equal(sample.rollout_routed_experts, second_routed_experts) + assert len(sample.tokens) - 1 == second_routed_experts.shape[0] diff --git a/tests/fast/rollout/generate_hub/test_single_turn.py b/tests/fast/rollout/generate_hub/test_single_turn.py new file mode 100644 index 000000000..a58e6fb3c --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_single_turn.py @@ -0,0 +1,424 @@ +import numpy as np +import pybase64 +import pytest +import torch +from PIL import Image +from tests.fast.fixtures.generation_fixtures import GenerateEnv, 
generation_env, listify, make_sample, run_generate +from transformers import AutoProcessor + +from miles.utils.processing_utils import encode_image_for_rollout_engine +from miles.utils.test_utils.mock_sglang_server import ProcessResult, ProcessResultMetaInfo +from miles.utils.types import Sample + +_ = generation_env + +# ------------------------------------ fixtures and consts ---------------------------------------- + + +MODEL_NAME = "Qwen/Qwen3-0.6B" +PROMPT = "What is 1+7?" +PROMPT_TOKENS = [3838, 374, 220, 16, 10, 22, 30] +PROMPT_TOKEN_LEN = len(PROMPT_TOKENS) +RESPONSE_TOKENS = [59, 79075, 90, 23, 92] +RESPONSE_TEXT = "\\boxed{8}" +RESPONSE_LOG_PROBS = [-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125] +SAMPLING_PARAMS = {"max_new_tokens": 16, "temperature": 0.7} +DEFAULT_MAX_NEW_TOKENS = SAMPLING_PARAMS["max_new_tokens"] + + +@pytest.fixture(params=["old_sglang_rollout", "single_turn", "multi_turn_single_sample", "multi_turn_multi_samples"]) +def variant(request): + return request.param + + +def expected_request( + variant: str, + *, + input_ids: list[int] | None = None, + sampling_params: dict | None = None, + return_routed_experts: bool = False, + image_data: list[str] | None = None, +) -> dict: + result = { + "input_ids": input_ids or PROMPT_TOKENS, + "sampling_params": sampling_params or SAMPLING_PARAMS, + "return_logprob": True, + } + if variant in ("single_turn", "multi_turn_single_sample", "multi_turn_multi_samples") or return_routed_experts: + result["return_routed_experts"] = return_routed_experts + if image_data is not None: + result["image_data"] = image_data + return result + + +class _Unset: + pass + + +_UNSET = _Unset() + + +def expected_sample( + variant: str, + *, + prompt: str = PROMPT, + response: str = RESPONSE_TEXT, + response_length: int = 5, + tokens: list[int] | None | _Unset = _UNSET, + rollout_log_probs: list[float] | None | _Unset = _UNSET, + status: Sample.Status = Sample.Status.COMPLETED, + cached_tokens: int = 0, + prompt_tokens: int = 7, + weight_versions: list[str] | None = None, + rollout_routed_experts: np.ndarray | None = None, + spec_info: Sample.SpecInfo | None = None, + multimodal_inputs: dict | None = None, + multimodal_train_inputs: dict | None = None, + loss_mask: list[int] | None | _Unset = _UNSET, +) -> Sample: + actual_response_length = response_length if response_length is not None else len(RESPONSE_TOKENS) + if isinstance(loss_mask, _Unset): + loss_mask = ( + [1] * actual_response_length + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") + else None + ) + + return Sample( + group_index=None, + index=None, + prompt=prompt, + tokens=PROMPT_TOKENS + RESPONSE_TOKENS if isinstance(tokens, _Unset) else tokens, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=multimodal_train_inputs, + response=response, + response_length=response_length, + label=None, + reward=None, + loss_mask=loss_mask, + weight_versions=weight_versions or [], + rollout_log_probs=RESPONSE_LOG_PROBS if isinstance(rollout_log_probs, _Unset) else rollout_log_probs, + rollout_routed_experts=rollout_routed_experts, + remove_sample=False, + status=status, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=spec_info or Sample.SpecInfo(), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=cached_tokens, total_prompt_tokens=prompt_tokens), + ) + + +def _make_sample(tokens=None, response="", response_length=0, status=Sample.Status.PENDING, multimodal_inputs=None): + return make_sample( + prompt=PROMPT, + 
tokens=tokens, + response=response, + response_length=response_length, + status=status, + multimodal_inputs=multimodal_inputs, + ) + + +def _run_generate(variant: str, env: GenerateEnv, sample: Sample | None = None, sampling_params: dict | None = None): + return run_generate(env, sample or _make_sample(), sampling_params or SAMPLING_PARAMS, variant=variant) + + +# ------------------------------------ tests ---------------------------------------- + + +class TestBasicGeneration: + def test_basic_generation(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant)] + + +class TestResumedSingleTurn: + def test_two_consecutive_calls_on_same_sample(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + partial_text = "\\boxed" + partial_tokens = [59, 79075] + partial_log_probs = [-0.0, -0.0078125] + + remaining_text = "{8}" + remaining_tokens = [90, 23, 92] + remaining_log_probs = [-0.0, -0.0078125, -0.015625] + + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=partial_text, finish_reason="abort") + sample = _make_sample() + result1 = _run_generate(variant, generation_env, sample) + assert result1.requests == [expected_request(variant)] + assert result1.sample == expected_sample( + variant, + response=partial_text, + response_length=2, + tokens=PROMPT_TOKENS + partial_tokens, + rollout_log_probs=partial_log_probs, + status=Sample.Status.ABORTED, + ) + + generation_env.mock_server.process_fn = lambda _: ProcessResult(text=remaining_text, finish_reason="stop") + result2 = _run_generate(variant, generation_env, result1.sample) + tokens_after_turn1 = PROMPT_TOKENS + partial_tokens + assert result2.requests == [ + expected_request( + variant, + input_ids=tokens_after_turn1, + sampling_params={"max_new_tokens": 14, "temperature": 0.7}, + ) + ] + assert result2.sample == expected_sample( + variant, + response=partial_text + remaining_text, + response_length=2 + 3, + tokens=tokens_after_turn1 + remaining_tokens, + rollout_log_probs=partial_log_probs + remaining_log_probs, + prompt_tokens=len(PROMPT_TOKENS) + len(tokens_after_turn1), + status=Sample.Status.COMPLETED, + ) + + +class TestFinishReason: + @pytest.mark.parametrize( + "generation_env,expected_status", + [ + ({"process_fn_kwargs": {"finish_reason": "stop"}}, Sample.Status.COMPLETED), + ({"process_fn_kwargs": {"finish_reason": "length"}}, Sample.Status.TRUNCATED), + ({"process_fn_kwargs": {"finish_reason": "abort"}}, Sample.Status.ABORTED), + ], + indirect=["generation_env"], + ) + def test_finish_reason_sets_status(self, variant, generation_env, expected_status): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant, status=expected_status)] + + +class TestRoutedExperts: + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": {"use_rollout_routing_replay": True}, + "process_fn_kwargs": {"routed_experts": "placeholder"}, + } + ], + indirect=True, + ) + def test_routed_experts_enabled_and_parsed(self, variant, generation_env): + num_layers, moe_router_topk = 2, 4 + num_tokens = len(PROMPT_TOKENS) + len(RESPONSE_TOKENS) + routed_experts_array = np.arange((num_tokens - 1) * num_layers * moe_router_topk, dtype=np.int32).reshape( + num_tokens - 1, num_layers, moe_router_topk + ) + 
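+        # The mock server returns this array base64-encoded in meta_info; the parsed sample's
+        # rollout_routed_experts must reproduce it exactly (shape: num_tokens - 1, num_layers, moe_router_topk).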
+ generation_env.args.num_layers = num_layers + generation_env.args.moe_router_topk = moe_router_topk + routed_experts_str = pybase64.b64encode(routed_experts_array.tobytes()).decode("ascii") + generation_env.mock_server.process_fn = lambda _: ProcessResult( + text=RESPONSE_TEXT, + finish_reason="stop", + meta_info=ProcessResultMetaInfo(routed_experts=routed_experts_str), + ) + + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant, return_routed_experts=True)] + sample = result.sample[0] if isinstance(result.sample, list) else result.sample + assert sample.rollout_routed_experts is not None + assert sample.rollout_routed_experts.shape == (num_tokens - 1, num_layers, moe_router_topk) + np.testing.assert_array_equal(sample.rollout_routed_experts, routed_experts_array) + + +class TestMetaInfo: + @pytest.mark.parametrize( + "generation_env", [{"process_fn_kwargs": {"cached_tokens": 3, "weight_version": "v1.0"}}], indirect=True + ) + def test_meta_info_fields_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant, cached_tokens=3, weight_versions=["v1.0"])] + + @pytest.mark.parametrize( + "generation_env", + [ + { + "args_kwargs": {"sglang_speculative_algorithm": "EAGLE"}, + "process_fn_kwargs": {"spec_accept_token_num": 10, "spec_draft_token_num": 15, "spec_verify_ct": 3}, + } + ], + indirect=True, + ) + def test_spec_info_updated(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [ + expected_sample( + variant, + spec_info=Sample.SpecInfo( + spec_accept_token_num=10, spec_draft_token_num=15, spec_verify_ct=3, completion_token_num=5 + ), + ) + ] + + +class TestInputStatusValidation: + @pytest.mark.parametrize("status", [Sample.Status.PENDING, Sample.Status.ABORTED]) + def test_allowed_statuses(self, variant, generation_env, status): + result = _run_generate(variant, generation_env, _make_sample(status=status)) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [expected_sample(variant)] + + @pytest.mark.parametrize("status", [Sample.Status.COMPLETED, Sample.Status.TRUNCATED]) + def test_rejected_statuses(self, variant, generation_env, status): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + with pytest.raises(AssertionError): + _run_generate(variant, generation_env, _make_sample(status=status)) + + +class TestPayloadStructure: + def test_sampling_params_passed_through(self, variant, generation_env): + result = _run_generate( + variant, generation_env, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9} + ) + assert result.requests == [ + expected_request(variant, sampling_params={"max_new_tokens": 16, "temperature": 0.5, "top_p": 0.9}) + ] + assert listify(result.sample) == [expected_sample(variant)] + + +class TestBoundaryConditions: + def test_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + existing_tokens = [1, 2, 3, 4, 5, 6, 7] + list(range(100, 110)) + sample = _make_sample(tokens=existing_tokens, response="x" * 10, response_length=10) + + result = _run_generate(variant, generation_env, sample, {"max_new_tokens": 10, 
"temperature": 0.7}) + assert result.requests == [] + assert result.sample == expected_sample( + variant, + response="x" * 10, + response_length=10, + tokens=existing_tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + ) + + @pytest.mark.parametrize("generation_env", [{"args_kwargs": {"rollout_max_context_len": 5}}], indirect=True) + def test_prompt_exceeds_max_context_len_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") + result = _run_generate(variant, generation_env) + assert result.requests == [] + tokens = PROMPT_TOKENS if variant in ("multi_turn_single_sample", "multi_turn_multi_samples") else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] + + @pytest.mark.parametrize( + "generation_env,expected_max_new_tokens", + [ + ({"args_kwargs": {"rollout_max_context_len": 10}}, 10 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 8}}, 8 - PROMPT_TOKEN_LEN), + ({"args_kwargs": {"rollout_max_context_len": 100}}, DEFAULT_MAX_NEW_TOKENS), + ], + indirect=["generation_env"], + ) + def test_moderate_length_input_adjusts_max_new_tokens(self, variant, generation_env, expected_max_new_tokens): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + result = _run_generate(variant, generation_env) + assert len(result.requests) == 1 + assert result.requests[0]["sampling_params"]["max_new_tokens"] == expected_max_new_tokens + assert result.requests[0]["sampling_params"]["temperature"] == SAMPLING_PARAMS["temperature"] + assert listify(result.sample) == [expected_sample(variant)] + + @pytest.mark.parametrize( + "generation_env", + [{"args_kwargs": {"rollout_max_context_len": PROMPT_TOKEN_LEN}}], + indirect=True, + ) + def test_adjusted_max_new_tokens_zero_returns_truncated(self, variant, generation_env): + if variant == "old_sglang_rollout": + pytest.skip("old_sglang_rollout does not support rollout_max_context_len") + if variant == "multi_turn_multi_samples": + pytest.skip("multi_turn_multi_samples returns empty list when first turn fails") + result = _run_generate(variant, generation_env) + assert result.requests == [] + tokens = PROMPT_TOKENS if variant == "multi_turn_single_sample" else [] + assert listify(result.sample) == [ + expected_sample( + variant, + response="", + response_length=0, + tokens=tokens, + rollout_log_probs=None, + status=Sample.Status.TRUNCATED, + prompt_tokens=0, + loss_mask=None if variant == "multi_turn_single_sample" else _UNSET, + ) + ] + + +class TestEmptyResponse: + @pytest.mark.parametrize("generation_env", [{"process_fn_kwargs": {"response_text": ""}}], indirect=True) + def test_empty_response(self, variant, generation_env): + result = _run_generate(variant, generation_env) + assert result.requests == [expected_request(variant)] + assert listify(result.sample) == [ + expected_sample(variant, response="", response_length=0, tokens=PROMPT_TOKENS, rollout_log_probs=[]) + ] + + +VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct" + + +class TestMultimodal: + @pytest.mark.parametrize("generation_env", 
[{"args_kwargs": {"model_name": VLM_MODEL_NAME}}], indirect=True) + def test_multimodal_inputs_processed(self, variant, generation_env): + if variant in ("multi_turn_single_sample", "multi_turn_multi_samples"): + pytest.skip("not tested yet") + test_image = Image.new("RGB", (64, 64), color="red") + multimodal_inputs = {"images": [test_image]} + processor = AutoProcessor.from_pretrained(VLM_MODEL_NAME, trust_remote_code=True) + expected_mti = { + k: v + for k, v in processor(text=PROMPT, **multimodal_inputs).items() + if k not in ["input_ids", "attention_mask"] + } + + result = _run_generate(variant, generation_env, _make_sample(multimodal_inputs=multimodal_inputs)) + + assert result.requests == [ + expected_request( + variant, + input_ids=PROMPT_TOKENS, + image_data=[encode_image_for_rollout_engine(test_image)], + ) + ] + actual_mti = result.sample.multimodal_train_inputs + assert actual_mti is not None + assert set(actual_mti.keys()) == set(expected_mti.keys()) + assert torch.all(actual_mti["pixel_values"] == expected_mti["pixel_values"]) + assert torch.all(actual_mti["image_grid_thw"] == expected_mti["image_grid_thw"]) + assert result.sample == expected_sample( + variant, + tokens=PROMPT_TOKENS + RESPONSE_TOKENS, + multimodal_inputs=multimodal_inputs, + multimodal_train_inputs=actual_mti, + ) diff --git a/tests/fast/rollout/generate_hub/test_tool_call_utils.py b/tests/fast/rollout/generate_hub/test_tool_call_utils.py new file mode 100644 index 000000000..0f2305e75 --- /dev/null +++ b/tests/fast/rollout/generate_hub/test_tool_call_utils.py @@ -0,0 +1,99 @@ +import pytest + +from miles.rollout.generate_utils.tool_call_utils import _DUMMY_USER, _build_dummy_assistant, tokenize_tool_responses + +TOOL_CALL_TEST_MODELS = [ + "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen3-0.6B", + "Qwen/Qwen3-4B-Instruct-2507", + "Qwen/Qwen3-Coder-30B-A3B-Instruct", + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI + "mistralai/Mistral-7B-Instruct-v0.3", + "deepseek-ai/DeepSeek-V3", + "stepfun-ai/step3", + "MiniMaxAI/MiniMax-M2", + "internlm/internlm3-8b-instruct", + "THUDM/glm-4-9b-chat", + "moonshotai/Kimi-K2-Instruct", + "XiaomiMiMo/MiMo-7B-RL", +] + +SINGLE_TOOL_CALL_ONLY_MODELS = [ + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo +] + +# Models where tokenize->decode produces extra whitespace vs direct string diff +TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS = [ + "THUDM/glm-4-9b-chat", +] + +SAMPLE_TOOL_RESPONSES = [ + { + "role": "tool", + "tool_call_id": "call00000", + "content": '{"year": 2026}', + "name": "get_year", + }, + { + "role": "tool", + "tool_call_id": "call00001", + "content": '{"temperature": 25}', + "name": "get_temperature", + }, +] + + +class TestTokenizeToolResponses: + @pytest.mark.parametrize("model_name", ["Qwen/Qwen3-0.6B"]) + def test_snapshot(self, model_name): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + token_ids = tokenize_tool_responses(SAMPLE_TOOL_RESPONSES, tokenizer) + decoded = tokenizer.decode(token_ids) + + assert decoded == ( + "<|im_start|>user\n" + "\n" + '{"year": 2026}\n' + "\n" + "\n" + '{"temperature": 25}\n' + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + @pytest.mark.parametrize("num_tools", [1, 2]) + @pytest.mark.parametrize("model_name", TOOL_CALL_TEST_MODELS) + def test_tokenize_tool_responses(self, model_name, num_tools): + if num_tools > 1 and model_name in SINGLE_TOOL_CALL_ONLY_MODELS: + pytest.skip(f"{model_name} only supports 
single tool call") + + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + tool_responses = SAMPLE_TOOL_RESPONSES[:num_tools] + assert len(tool_responses) == num_tools + + actual_token_ids = tokenize_tool_responses(tool_responses, tokenizer) + actual_str = tokenizer.decode(actual_token_ids) + + dummy_assistant = _build_dummy_assistant(tool_responses) + base_messages = [_DUMMY_USER, dummy_assistant] + expected_str = self._compute_chat_template_diff(base_messages, tool_responses, tokenizer) + + if model_name in TOKENIZE_DECODE_WHITESPACE_DIFF_MODELS: + # Some models produce whitespace differences between tokenize->decode and direct string diff + actual_str = actual_str.replace(" ", "") + expected_str = expected_str.replace(" ", "") + + assert actual_str == expected_str, f"{model_name=}" + + @staticmethod + def _compute_chat_template_diff(base_messages, extra_messages, tokenizer) -> str: + text_with = tokenizer.apply_chat_template( + base_messages + extra_messages, tokenize=False, add_generation_prompt=True + ) + text_without = tokenizer.apply_chat_template(base_messages, tokenize=False, add_generation_prompt=False) + return text_with[len(text_without) :] diff --git a/tests/fast/rollout/generate_utils/__init__.py b/tests/fast/rollout/generate_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/generate_utils/test_sample_utils.py b/tests/fast/rollout/generate_utils/test_sample_utils.py new file mode 100644 index 000000000..c53fbbb56 --- /dev/null +++ b/tests/fast/rollout/generate_utils/test_sample_utils.py @@ -0,0 +1,156 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.generate_utils.sample_utils import _merge_sample_pair +from miles.utils.types import Sample + + +@pytest.fixture +def mock_tokenizer(): + tokenizer = MagicMock() + tokenizer.decode = lambda tokens: f"" + return tokenizer + + +def make_sample( + prompt="test_prompt", + tokens=None, + response="", + response_length=0, + loss_mask=None, + rollout_log_probs=None, + status=Sample.Status.COMPLETED, + label="test_label", + reward=1.0, + index=0, + group_index=0, +): + return Sample( + prompt=prompt, + tokens=tokens or [], + response=response, + response_length=response_length, + loss_mask=loss_mask, + rollout_log_probs=rollout_log_probs, + status=status, + label=label, + reward=reward, + index=index, + group_index=group_index, + ) + + +class TestMergeSamples: + def test_basic_merge(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3, 10, 11, 12], + response="response1", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.1, -0.2, -0.3], + ) + b = make_sample( + tokens=[1, 2, 3, 10, 11, 12, 20, 21, 30, 31, 32], + response="response2", + response_length=3, + loss_mask=[1, 1, 1], + rollout_log_probs=[-0.4, -0.5, -0.6], + status=Sample.Status.TRUNCATED, + ) + + merged = _merge_sample_pair(a, b, mock_tokenizer) + + assert merged.tokens == b.tokens + assert merged.response_length == 3 + 2 + 3 + assert merged.loss_mask == [1, 1, 1, 0, 0, 1, 1, 1] + assert merged.rollout_log_probs == [-0.1, -0.2, -0.3, 0.0, 0.0, -0.4, -0.5, -0.6] + assert merged.prompt == a.prompt + assert merged.status == b.status + assert merged.label == a.label + assert merged.index == a.index + assert merged.group_index == a.group_index + assert "response1" in merged.response + assert "response2" in merged.response + assert "" in merged.response + + def test_loss_mask_none_defaults_to_all_ones(self, 
mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=None, + rollout_log_probs=None, + ) + + merged = _merge_sample_pair(a, b, mock_tokenizer) + + assert merged.loss_mask == [1, 0, 1] + assert merged.rollout_log_probs == [0.0, 0.0, 0.0] + + def test_tokens_prefix_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 3], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 99, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="b.tokens must start with a.tokens"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_field_mismatch_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + index=0, + ) + b = make_sample( + tokens=[1, 2, 10, 20, 30], + response_length=1, + loss_mask=[1], + index=1, + ) + + with pytest.raises(AssertionError, match="index mismatch"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_obs_len_invalid_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10], + response_length=1, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="obs_len must be > 0"): + _merge_sample_pair(a, b, mock_tokenizer) + + def test_sample_validate_fails_raises(self, mock_tokenizer): + a = make_sample( + tokens=[1, 2, 10, 11], + response_length=2, + loss_mask=[1], + ) + b = make_sample( + tokens=[1, 2, 10, 11, 20, 30], + response_length=1, + loss_mask=[1], + ) + + with pytest.raises(AssertionError, match="loss_mask length"): + _merge_sample_pair(a, b, mock_tokenizer) diff --git a/tests/fast/rollout/inference_rollout/__init__.py b/tests/fast/rollout/inference_rollout/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/inference_rollout/conftest.py b/tests/fast/rollout/inference_rollout/conftest.py new file mode 100644 index 000000000..ca47edeeb --- /dev/null +++ b/tests/fast/rollout/inference_rollout/conftest.py @@ -0,0 +1,45 @@ +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import parse_args + + +def _build_mock_args(extra_argv: list[str] | None = None): + argv = [ + "pytest", + "--train-backend", + "fsdp", + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "1", + "--num-rollout", + "1", + "--rollout-num-gpus", + "4", + "--rollout-num-gpus-per-engine", + "2", + "--hf-checkpoint", + "Qwen/Qwen3-0.6B", + "--prompt-data", + "/dev/null", + "--input-key", + "input", + "--label-key", + "label", + "--rm-type", + "math", + "--use-miles-router", + "--sglang-router-ip", + "127.0.0.1", + "--sglang-router-port", + "30000", + ] + (extra_argv or []) + with patch("sys.argv", argv): + return parse_args() + + +@pytest.fixture +def mock_args(): + return _build_mock_args() diff --git a/tests/fast/rollout/inference_rollout/integration/__init__.py b/tests/fast/rollout/inference_rollout/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/inference_rollout/integration/test_basic.py b/tests/fast/rollout/inference_rollout/integration/test_basic.py new file mode 100644 index 000000000..5b791829d --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_basic.py @@ -0,0 +1,69 @@ +import pytest +from tests.fast.fixtures.generation_fixtures import 
extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import ( + MODULAR_ROLLOUT_BASE_ARGV, + expected_sample, + load_and_call_train, +) + +from miles.rollout.base_types import RolloutFnConstructorInput, RolloutFnEvalInput +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function + +_VARIANTS = [ + pytest.param( + RolloutEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--eval-function-path", + "miles.rollout.sglang_rollout.generate_rollout", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="old_rollout_old_generate", + ), + pytest.param( + RolloutEnvConfig( + extra_argv=[ + "--rollout-function-path", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", + "--custom-generate-function-path", + "miles.rollout.sglang_rollout.generate", + ] + ), + id="new_rollout_old_generate", + ), + pytest.param( + RolloutEnvConfig(extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant("single_turn")), + id="new_rollout_new_generate", + ), +] + + +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_train(rollout_env): + env = rollout_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert len(group) == env.args.n_samples_per_prompt + assert group[0] == expected_sample(group_index=0) + + +@pytest.mark.parametrize("rollout_env", _VARIANTS, indirect=True) +def test_eval(rollout_env): + env = rollout_env + fn = load_rollout_function( + RolloutFnConstructorInput(args=env.args, data_source=env.data_source), env.args.eval_function_path + ) + out = call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + assert "toy" in out.data + rewards = out.data["toy"]["rewards"] + samples = out.data["toy"]["samples"] + assert len(rewards) == len(samples) == env.args.n_samples_per_eval_prompt + assert rewards[0] == 1 + assert samples[0] == expected_sample(group_index=None) diff --git a/tests/fast/rollout/inference_rollout/integration/test_deterministic.py b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py new file mode 100644 index 000000000..69a235911 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_deterministic.py @@ -0,0 +1,37 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + + +@pytest.mark.parametrize( + "rollout_env,expected_seeds", + [ + pytest.param( + integration_env_config( + [ + "--sglang-enable-deterministic-inference", + "--rollout-seed", + "42", + "--n-samples-per-prompt", + "3", + "--rollout-batch-size", + "1", + ] + ), + {42, 43, 44}, + id="enabled", + ), + pytest.param( + integration_env_config(["--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + {None}, + id="disabled", + ), + ], + indirect=["rollout_env"], +) +def test_sampling_seeds(rollout_env, expected_seeds): + env = rollout_env + load_and_call_train(env.args, env.data_source) + + seeds = {req.get("sampling_params", {}).get("sampling_seed") for req in env.mock_server.request_log} + assert seeds == expected_seeds diff --git a/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py new file mode 100644 index 000000000..0ca5743ac 
--- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_dynamic_filter.py @@ -0,0 +1,46 @@ +from contextlib import nullcontext + +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + MIXED_DATA_ROWS, + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + + +@pytest.mark.parametrize( + "rollout_env,use_filter,expect_all_correct", + [ + pytest.param( + integration_env_config(["--rollout-batch-size", "4"], data_rows=MIXED_DATA_ROWS), + False, + False, + id="no_filter", + ), + pytest.param( + integration_env_config( + ["--rollout-batch-size", "3", "--dynamic-sampling-filter-path", "test:filter_by_reward"], + data_rows=MIXED_DATA_ROWS, + ), + True, + True, + id="with_filter", + ), + ], + indirect=["rollout_env"], +) +def test_filter_effect(rollout_env, use_filter, expect_all_correct): + env = rollout_env + ctx = function_registry.temporary("test:filter_by_reward", filter_by_reward) if use_filter else nullcontext() + + with ctx: + out = load_and_call_train(env.args, env.data_source) + + rewards = {group[0].reward for group in out.samples} + if expect_all_correct: + assert rewards == {1}, "Filter should keep only correct samples" + else: + assert 0 in rewards, "Without filter, incorrect samples should be present" diff --git a/tests/fast/rollout/inference_rollout/integration/test_group_rm.py b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py new file mode 100644 index 000000000..afd870c30 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_group_rm.py @@ -0,0 +1,22 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + integration_env_config(["--group-rm", "--n-samples-per-prompt", "2", "--rollout-batch-size", "1"]), + id="group_rm_enabled", + ), + ], + indirect=True, +) +def test_group_rm_rewards_set(rollout_env): + env = rollout_env + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + rewards = [sample.reward for group in out.samples for sample in group] + assert all(r in (0, 1) for r in rewards) diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py new file mode 100644 index 000000000..2b12d3d88 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_sample.py @@ -0,0 +1,65 @@ +import pytest +from tests.fast.fixtures.rollout_fixtures import DEFAULT_DATA_ROWS, RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_train + +from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput +from miles.utils.misc import function_registry +from miles.utils.types import Sample + + +async def _multi_sample_generate(input: GenerateFnInput) -> GenerateFnOutput: + sample = input.sample + s1 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=None, + status=Sample.Status.COMPLETED, + ) + s2 = Sample( + prompt=sample.prompt, + response="\\boxed{8}", + response_length=5, + tokens=sample.tokens + [59, 79075, 90, 23, 92], + label=sample.label, + reward=0.5, + status=Sample.Status.COMPLETED, + ) + return GenerateFnOutput(samples=[s1, s2]) + + 
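+# The test below registers _multi_sample_generate under "test:multi_sample_generate" and checks that
+# s1 (reward=None) gets scored by the reward model while s2's preset reward of 0.5 is preserved.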
+@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + + [ + "--custom-generate-function-path", + "test:multi_sample_generate", + "--rollout-batch-size", + "1", + "--n-samples-per-prompt", + "1", + ], + data_rows=DEFAULT_DATA_ROWS, + ), + id="multi_sample_output", + ), + ], + indirect=True, +) +def test_multi_sample_output_preserves_existing_reward(rollout_env): + env = rollout_env + with function_registry.temporary("test:multi_sample_generate", _multi_sample_generate): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + assert isinstance(group[0], list) + samples = group[0] + assert len(samples) == 2 + assert samples[0].reward == 1 + assert samples[1].reward == 0.5 diff --git a/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py new file mode 100644 index 000000000..c41d71399 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_multi_turn.py @@ -0,0 +1,114 @@ +from typing import Any + +import pytest +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig +from tests.fast.rollout.inference_rollout.integration.utils import MODULAR_ROLLOUT_BASE_ARGV, load_and_call_rollout + +from miles.utils.test_utils.mock_tools import TwoTurnStub +from miles.utils.types import Sample + + +TWO_TURN_DATA_ROWS = [{"input": [{"role": "user", "content": TwoTurnStub.USER_QUESTION}], "label": "2008"}] + +_VARIANT_NAMES = [ + "multi_turn_single_sample", + "multi_turn_multi_samples", + "agentic_tool_call_single_sample", + "agentic_tool_call_multi_samples", +] + +_BASE_EXTRA_ARGV = [ + "--rollout-batch-size", + "2", + "--n-samples-per-prompt", + "2", + "--n-samples-per-eval-prompt", + "2", + "--custom-rm-path", + "tests.fast.rollout.inference_rollout.integration.test_multi_turn._simple_reward_function", +] + + +def _config_for_variant(variant: str) -> RolloutEnvConfig: + return RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + _BASE_EXTRA_ARGV, + data_rows=TWO_TURN_DATA_ROWS, + ) + + +@pytest.mark.parametrize( + "variant,rollout_env", + [pytest.param(variant, _config_for_variant(variant), id=variant) for variant in _VARIANT_NAMES], + indirect=["rollout_env"], +) +@pytest.mark.parametrize("test_type", ["train", "eval"]) +def test_rollout(rollout_env, variant, test_type): + env = rollout_env + env.mock_server.process_fn = TwoTurnStub.process_fn + + out = load_and_call_rollout(env.args, env.data_source, mode=test_type) + + if test_type == "train": + assert len(out.samples) == env.args.rollout_batch_size + group = out.samples[0] + _verify_samples(variant, group) + else: + assert "toy" in out.data + samples = out.data["toy"]["samples"] + _verify_samples(variant, samples) + + +def _verify_samples(variant: str, samples: list[Any]): + is_multi_samples = variant in ("multi_turn_multi_samples", "agentic_tool_call_multi_samples") + + if is_multi_samples: + if len(samples) > 0 and isinstance(samples[0], list): + # Train mode: list[list[Sample]], grouped by prompt + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for group_sample in samples: + assert isinstance(group_sample, list), "multi_samples variant should return list[Sample] per generate" + _verify_group_samples(group_sample) + 
else: + # Eval mode: list[Sample], flattened + # n_samples_per_eval_prompt=2, and each generate returns 2 turns, so 2*2=4 samples + assert ( + len(samples) == 4 + ), f"n_samples_per_eval_prompt=2, each generate returns 2 turns, so should have 4 samples, got {len(samples)}" + # Group samples by prompt (every 2 samples form a group) + group_samples_list = [samples[i : i + 2] for i in range(0, len(samples), 2)] + for group_samples in group_samples_list: + _verify_group_samples(group_samples) + else: + assert len(samples) == 2, f"n_samples_per_prompt=2, so group should have 2 samples, got {len(samples)}" + for sample in samples: + assert isinstance(sample, Sample), "single_sample variant should return Sample, not list" + _verify_sample(sample) + + +def _verify_group_samples(group_samples: list[Sample], expected_count: int = 2): + assert len(group_samples) == expected_count, f"Group should have {expected_count} samples (one per turn)" + for i, sample in enumerate(group_samples): + _verify_sample(sample, expect_answer=(i == len(group_samples) - 1)) + + +def _verify_sample(sample: Sample, expected_reward: float = 1.0, expect_answer: bool = True): + assert sample.status == Sample.Status.COMPLETED + assert sample.reward == expected_reward, f"Sample should have reward={expected_reward}" + if expect_answer: + assert "2008" in sample.response, "Response should contain final answer '2008'" + + +async def _simple_reward_function(args, samples: Sample | list[Sample]) -> float | list[float]: + if isinstance(samples, list): + # For multi_samples variants, use the last sample's reward + if getattr(args, "generate_multi_samples", False): + return [_check_reward(samples[-1])] * len(samples) + else: + return [_check_reward(sample) for sample in samples] + else: + return _check_reward(samples) + + +def _check_reward(sample: Sample) -> float: + return float(sample.response and (str(sample.label) in sample.response)) diff --git a/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py new file mode 100644 index 000000000..0812962cc --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_over_sampling.py @@ -0,0 +1,48 @@ +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + +_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "wrong"}, + {"input": "What is 1+9?", "label": "wrong"}, + {"input": "What is 1+6?", "label": "wrong"}, +] + +_BASE_ARGV = [ + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", +] + + +def _over_sampling_config(rollout_batch_size: int): + return integration_env_config(["--rollout-batch-size", str(rollout_batch_size)] + _BASE_ARGV, data_rows=_DATA_ROWS) + + +@pytest.mark.parametrize( + "rollout_env,expected_rounds", + [ + pytest.param(_over_sampling_config(1), 1, id="one_round"), + pytest.param(_over_sampling_config(2), 2, id="two_rounds"), + ], + indirect=["rollout_env"], +) +def test_over_sampling_rounds(rollout_env, expected_rounds): + env = rollout_env + + with function_registry.temporary("test:filter_by_reward", filter_by_reward): + out = load_and_call_train(env.args, env.data_source) + + assert len(out.samples) == env.args.rollout_batch_size + assert all(group[0].reward == 1 for group in out.samples) + + requests_count = 
len(env.mock_server.request_log) + expected_requests = expected_rounds * env.args.over_sampling_batch_size + assert requests_count == expected_requests, f"Expected {expected_rounds} round(s) = {expected_requests} requests" diff --git a/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py new file mode 100644 index 000000000..36e78c16c --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_sample_filter.py @@ -0,0 +1,67 @@ +from unittest.mock import Mock + +import pytest +from tests.fast.rollout.inference_rollout.integration.utils import ( + filter_by_reward, + integration_env_config, + load_and_call_train, +) + +from miles.utils.misc import function_registry + +# Data with only 2 reward=1 samples out of 4. +# This ensures all 4 samples must be generated to collect 2 valid ones. +_FILTER_TEST_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, # reward=1 + {"input": "What is 1+8?", "label": "wrong"}, # reward=0 + {"input": "What is 1+9?", "label": "wrong"}, # reward=0 + {"input": "What is 1+6?", "label": "7"}, # reward=1 +] + + +@pytest.mark.parametrize( + "rollout_env", + [ + pytest.param( + integration_env_config( + [ + "--rollout-batch-size", + "2", + "--over-sampling-batch-size", + "4", + "--dynamic-sampling-filter-path", + "test:filter_by_reward", + "--rollout-sample-filter-path", + "test:sample_filter", + "--rollout-all-samples-process-path", + "test:all_samples_process", + ], + data_rows=_FILTER_TEST_DATA_ROWS, + ), + id="sample_filter_vs_all_samples", + ), + ], + indirect=True, +) +def test_sample_filter_and_all_samples_process(rollout_env): + env = rollout_env + sample_filter_mock = Mock() + all_samples_process_mock = Mock() + + with ( + function_registry.temporary("test:filter_by_reward", filter_by_reward), + function_registry.temporary("test:sample_filter", sample_filter_mock), + function_registry.temporary("test:all_samples_process", all_samples_process_mock), + ): + load_and_call_train(env.args, env.data_source) + + sample_filter_mock.assert_called_once() + _, filtered_data = sample_filter_mock.call_args[0] + rewards = [g[0][0].reward if isinstance(g[0], list) else g[0].reward for g in filtered_data] + assert all(r == 1 for r in rewards) + + all_samples_process_mock.assert_called_once() + _, all_samples, data_source = all_samples_process_mock.call_args[0] + assert data_source is not None + + assert len(all_samples) > len(filtered_data), "all_samples_process should see more samples than sample_filter" diff --git a/tests/fast/rollout/inference_rollout/integration/test_semaphore.py b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py new file mode 100644 index 000000000..889a9ff8a --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/test_semaphore.py @@ -0,0 +1,33 @@ +import pytest + +from tests.fast.rollout.inference_rollout.integration.utils import integration_env_config, load_and_call_train + +_DATA_ROWS = [{"input": f"What is 1+{i}?", "label": str(1 + i)} for i in range(10)] +_BASE_ARGV = ["--rollout-batch-size", "4", "--n-samples-per-prompt", "2"] + + +@pytest.mark.parametrize( + "rollout_env,expected_range", + [ + pytest.param( + integration_env_config( + ["--sglang-server-concurrency", "1"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), + (1, 1), + id="limit_1", + ), + pytest.param( + integration_env_config( + ["--sglang-server-concurrency", "999"] + _BASE_ARGV, data_rows=_DATA_ROWS, latency=0.05 + ), + (2, 999), + 
id="no_limit", + ), + ], + indirect=["rollout_env"], +) +def test_max_concurrent(rollout_env, expected_range): + env = rollout_env + load_and_call_train(env.args, env.data_source) + min_expected, max_expected = expected_range + assert min_expected <= env.mock_server.max_concurrent <= max_expected diff --git a/tests/fast/rollout/inference_rollout/integration/utils.py b/tests/fast/rollout/inference_rollout/integration/utils.py new file mode 100644 index 000000000..ad413cf94 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/integration/utils.py @@ -0,0 +1,89 @@ +from tests.fast.fixtures.generation_fixtures import extra_argv_for_variant +from tests.fast.fixtures.rollout_fixtures import RolloutEnvConfig + +from miles.rollout.base_types import ( + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnOutput, + RolloutFnTrainInput, +) +from miles.rollout.filter_hub.base_types import DynamicFilterOutput +from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function +from miles.utils.types import Sample + + +def expected_sample(*, group_index: int | None) -> Sample: + return Sample( + group_index=group_index, + index=0, + prompt="What is 1+7?", + tokens=[3838, 374, 220, 16, 10, 22, 30, 59, 79075, 90, 23, 92], + multimodal_inputs=None, + multimodal_train_inputs=None, + response="\\boxed{8}", + response_length=5, + label="8", + reward=1, + loss_mask=None, + weight_versions=[], + rollout_log_probs=[-0.0, -0.0078125, -0.015625, -0.0234375, -0.03125], + rollout_routed_experts=None, + remove_sample=False, + status=Sample.Status.COMPLETED, + metadata={}, + train_metadata=None, + non_generation_time=0.0, + spec_info=Sample.SpecInfo( + spec_accept_token_num=0, spec_draft_token_num=0, spec_verify_ct=0, completion_token_num=0 + ), + prefix_cache_info=Sample.PrefixCacheInfo(cached_tokens=0, total_prompt_tokens=7), + ) + + +MODULAR_ROLLOUT_BASE_ARGV = [ + "--rollout-function-path", + "miles.rollout.inference_rollout.inference_rollout_common.InferenceRolloutFn", +] + +MIXED_DATA_ROWS = [ + {"input": "What is 1+7?", "label": "8"}, + {"input": "What is 1+8?", "label": "9"}, + {"input": "What is 1+9?", "label": "wrong"}, + {"input": "What is 1+6?", "label": "7"}, +] + + +def integration_env_config( + extra_argv: list[str], + data_rows: list[dict] | None = None, + latency: float = 0.0, + variant: str = "single_turn", +): + return RolloutEnvConfig( + extra_argv=MODULAR_ROLLOUT_BASE_ARGV + extra_argv_for_variant(variant) + extra_argv, + data_rows=data_rows, + latency=latency, + ) + + +def load_and_call_rollout(args, data_source, mode: str = "train") -> RolloutFnOutput: + function_path = args.rollout_function_path if mode == "train" else args.eval_function_path + fn = load_rollout_function( + RolloutFnConstructorInput(args=args, data_source=data_source), + function_path, + ) + if mode == "train": + return call_rollout_function(fn, RolloutFnTrainInput(rollout_id=0)) + else: + return call_rollout_function(fn, RolloutFnEvalInput(rollout_id=0)) + + +def load_and_call_train(args, data_source): + return load_and_call_rollout(args, data_source, mode="train") + + +def filter_by_reward(args, samples, **kwargs): + reward = samples[0].reward if not isinstance(samples[0], list) else samples[0][0].reward + if reward == 1: + return DynamicFilterOutput(keep=True) + return DynamicFilterOutput(keep=False, reason="reward_zero") diff --git a/tests/fast/rollout/inference_rollout/test_compatibility.py b/tests/fast/rollout/inference_rollout/test_compatibility.py new file mode 100644 
index 000000000..ddfecd067 --- /dev/null +++ b/tests/fast/rollout/inference_rollout/test_compatibility.py @@ -0,0 +1,196 @@ +import asyncio +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.base_types import ( + GenerateFnInput, + GenerateFnOutput, + RolloutFnConstructorInput, + RolloutFnEvalInput, + RolloutFnEvalOutput, + RolloutFnTrainInput, + RolloutFnTrainOutput, +) +from miles.rollout.inference_rollout.compatibility import ( + LegacyGenerateFnAdapter, + LegacyRolloutFnAdapter, + call_rollout_function, + load_generate_function, + load_rollout_function, +) +from miles.utils.async_utils import run +from miles.utils.misc import function_registry + + +@pytest.fixture +def constructor_input(): + return RolloutFnConstructorInput(args="dummy_args", data_source="dummy_data_source") + + +@pytest.fixture +def make_generate_fn_input(): + def _make(evaluation: bool = False): + state = MagicMock() + state.args = MagicMock() + + return GenerateFnInput( + state=state, + sample={"text": "test prompt"}, + sampling_params={"temperature": 0.7}, + evaluation=evaluation, + ) + + return _make + + +class TestSupportedRolloutFormats: + """ + Documentation test to show various supported rollout function formats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_raw_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return {"metric": {"accuracy": 0.9}} + return [[{"text": "sample"}]] + + with function_registry.temporary("test:legacy_rollout", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_rollout") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, LegacyRolloutFnAdapter) + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"metric": {"accuracy": 0.9}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "sample"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_typed_output(self, constructor_input, evaluation): + def legacy_rollout_fn(args, rollout_id, data_source, evaluation=False): + if evaluation: + return RolloutFnEvalOutput(data={"ds": {"acc": 0.95}}) + return RolloutFnTrainOutput(samples=[[{"text": "typed"}]]) + + with function_registry.temporary("test:legacy_typed", legacy_rollout_fn): + fn = load_rollout_function(constructor_input, "test:legacy_typed") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + if evaluation: + assert isinstance(result, RolloutFnEvalOutput) + assert result.data == {"ds": {"acc": 0.95}} + else: + assert isinstance(result, RolloutFnTrainOutput) + assert result.samples == [[{"text": "typed"}]] + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_sync_class(self, constructor_input, evaluation): + class SyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + def __call__(self, input): + if input.evaluation: + return RolloutFnEvalOutput(data={"test": {"score": 1}}) + return RolloutFnTrainOutput(samples=[[{"text": "sync"}]]) + + with function_registry.temporary("test:sync_class", SyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:sync_class") + + input_cls = RolloutFnEvalInput if evaluation else 
RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, SyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_async_class(self, constructor_input, evaluation): + class AsyncRolloutFn: + def __init__(self, input: RolloutFnConstructorInput): + pass + + async def __call__(self, input): + await asyncio.sleep(0.001) + if input.evaluation: + return RolloutFnEvalOutput(data={"benchmark": {"accuracy": 0.98}}) + return RolloutFnTrainOutput(samples=[[{"text": "async"}]]) + + with function_registry.temporary("test:async_class", AsyncRolloutFn): + fn = load_rollout_function(constructor_input, "test:async_class") + + input_cls = RolloutFnEvalInput if evaluation else RolloutFnTrainInput + result = call_rollout_function(fn, input_cls(rollout_id=1)) + + assert isinstance(fn, AsyncRolloutFn) + expected_type = RolloutFnEvalOutput if evaluation else RolloutFnTrainOutput + assert isinstance(result, expected_type) + + +class TestSupportedGenerateFormats: + """ + Documentation test similar to TestSupportedRolloutFormats + """ + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_1_legacy_function_with_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params, evaluation=False): + return "my_sample" + + with function_registry.temporary("test:legacy_gen_eval", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen_eval") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_2_legacy_function_without_evaluation_param(self, make_generate_fn_input, evaluation): + async def legacy_generate_fn(args, sample, sampling_params): + return "my_sample" + + with function_registry.temporary("test:legacy_gen", legacy_generate_fn): + fn = load_generate_function("test:legacy_gen") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, LegacyGenerateFnAdapter) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_3_new_async_function_api(self, make_generate_fn_input, evaluation): + async def generate(input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(samples="my_sample") + + with function_registry.temporary("test:new_async", generate): + fn = load_generate_function("test:new_async") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" + + @pytest.mark.parametrize("evaluation", [False, True]) + def test_format_4_new_class_api(self, make_generate_fn_input, evaluation): + class MyGenerateFn: + async def __call__(self, input: GenerateFnInput) -> GenerateFnOutput: + return GenerateFnOutput(samples="my_sample") + + with function_registry.temporary("test:new_class", MyGenerateFn): + fn = load_generate_function("test:new_class") + + result = run(fn(make_generate_fn_input(evaluation))) + + assert isinstance(fn, MyGenerateFn) + assert isinstance(result, GenerateFnOutput) + assert result.samples == "my_sample" diff --git a/tests/fast/rollout/rm_hub/__init__.py 
b/tests/fast/rollout/rm_hub/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/rollout/rm_hub/test_deepscaler.py b/tests/fast/rollout/rm_hub/test_deepscaler.py new file mode 100644 index 000000000..bd4c606a6 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_deepscaler.py @@ -0,0 +1,26 @@ +import pytest + +from miles.rollout.rm_hub.deepscaler import get_deepscaler_rule_based_reward + + +class TestGetDeepscalerRuleBasedReward: + @pytest.mark.parametrize( + "response,label,expected", + [ + (r"Let me analyze...The answer is \boxed{42}", "42", 1), + (r"Thinking...The answer is \boxed{wrong}", "42", 0), + (r"###Response\boxed{42}", "42", 1), + (r"###Response\boxed{wrong}", "42", 0), + (r"The answer is \boxed{42}", "42", 0), + (r"The answer is 42", "42", 0), + (r"\boxed{42}", "", 0), + (r"\boxed{42}", r"\boxed{42}", 1), + (r"\boxed{123}", 123, 1), + (r"\boxed{3.14}", 3.14, 1), + (r"\boxed{1/2}", "0.5", 1), + (r"\boxed{\frac{1}{2}}", "0.5", 1), + (r"First thoughtSecond thought\boxed{42}", "42", 1), + ], + ) + def test_get_deepscaler_rule_based_reward(self, response, label, expected): + assert get_deepscaler_rule_based_reward(response, label) == expected diff --git a/tests/fast/rollout/rm_hub/test_f1.py b/tests/fast/rollout/rm_hub/test_f1.py new file mode 100644 index 000000000..c9ecf9614 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_f1.py @@ -0,0 +1,44 @@ +import pytest + +from miles.rollout.rm_hub.f1 import f1_score, normalize_answer + + +class TestNormalizeAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("The quick brown fox", "quick brown fox"), + ("A cat and a dog", "cat and dog"), + ("Hello, world!", "hello world"), + (" multiple spaces ", "multiple spaces"), + ("An apple", "apple"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_normalize_answer(self, input_str, expected): + assert normalize_answer(input_str) == expected + + +class TestF1Score: + @pytest.mark.parametrize( + "prediction,ground_truth,expected_f1,expected_prec,expected_recall", + [ + ("hello world", "hello world", 1.0, 1.0, 1.0), + ("hello world foo", "hello world bar", 2 / 3, 2 / 3, 2 / 3), + ("abc", "xyz", 0, 0, 0), + (None, "anything", 0, 0, 0), + ("yes", "no", 0, 0, 0), + ("no", "yes", 0, 0, 0), + ("yes", "yes", 1.0, 1.0, 1.0), + ("noanswer", "yes", 0, 0, 0), + ("the answer is correct", "answer is correct", 1.0, 1.0, 1.0), + ("hello, world!", "hello world", 1.0, 1.0, 1.0), + ("hello", "hello world", pytest.approx(2 / 3), 1.0, 0.5), + ], + ) + def test_f1_score(self, prediction, ground_truth, expected_f1, expected_prec, expected_recall): + f1, prec, recall = f1_score(prediction, ground_truth) + assert f1 == expected_f1 + assert prec == expected_prec + assert recall == expected_recall diff --git a/tests/fast/rollout/rm_hub/test_gpqa.py b/tests/fast/rollout/rm_hub/test_gpqa.py new file mode 100644 index 000000000..45cefd201 --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_gpqa.py @@ -0,0 +1,86 @@ +import pytest + +from miles.rollout.rm_hub.gpqa import ( + _extract_letter_from_response, + _normalize_text, + _strip_chain_of_thought, + compute_gpqa_reward, +) + + +class TestStripChainOfThought: + @pytest.mark.parametrize( + "text,expected", + [ + ("Let me think...The answer is A", "The answer is A"), + ("The answer is A", "The answer is A"), + ("", ""), + (None, ""), + ], + ) + def test_strip_chain_of_thought(self, text, expected): + assert _strip_chain_of_thought(text) == expected + + +class TestNormalizeText: + 
@pytest.mark.parametrize( + "input_str,expected", + [ + ("Hello World", "hello world"), + ("Test-123", "test 123"), + ("A, B, C", "a b c"), + ("", ""), + ], + ) + def test_normalize_text(self, input_str, expected): + assert _normalize_text(input_str) == expected + + +class TestExtractLetterFromResponse: + @pytest.mark.parametrize( + "response,expected", + [ + ("The answer is A", "A"), + ("answer: B", "B"), + ("I think C is correct", "C"), + ("final answer: D", "D"), + ("Option A is the best choice", "A"), + ("The answer is B", "B"), + ("After analysis, my choice is C", "C"), + ("A B C D", "D"), + ("No valid letter here", None), + ("", None), + (None, None), + ("The answer is Z", None), + ], + ) + def test_extract_letter(self, response, expected): + assert _extract_letter_from_response(response, "ABCD") == expected + + +class TestComputeGpqaReward: + @pytest.mark.parametrize( + "response,label,metadata,expected", + [ + ("Answer: A", "A", None, 1.0), + ("Answer: A", "B", None, 0.0), + (None, "A", None, 0.0), + ("Answer: B", "ignored", {"correct_letter": "B"}, 1.0), + ("Answer: A", "ignored", {"correct_letter": "B"}, 0.0), + ("Answer: A", 0, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: B", 1, {"choices": ["Option 1", "Option 2", "Option 3", "Option 4"]}, 1.0), + ("Answer: X", "X", {"valid_letters": ["X", "Y", "Z"]}, 1.0), + ("Answer: A", "X", {"valid_letters": ["X", "Y", "Z"]}, 0.0), + ( + "I believe the answer is Paris", + "", + {"choices": ["Paris", "London", "Berlin", "Rome"], "correct_letter": "A"}, + 1.0, + ), + ("Answer: A", "", {"choices": {"A": "Paris", "B": "London"}, "correct_letter": "A"}, 1.0), + ("The answer is Paris", "Paris", {"choices": ["Paris", "London", "Berlin", "Rome"]}, 1.0), + ("Let me think step by step...The answer is A", "A", None, 1.0), + ], + ) + def test_compute_gpqa_reward(self, response, label, metadata, expected): + assert compute_gpqa_reward(response, label, metadata=metadata) == expected diff --git a/tests/fast/rollout/rm_hub/test_math_dapo_utils.py b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py new file mode 100644 index 000000000..56a7f6d1f --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_math_dapo_utils.py @@ -0,0 +1,108 @@ +import pytest + +from miles.rollout.rm_hub.math_dapo_utils import ( + compute_score, + is_correct_minerva, + is_correct_strict_box, + last_boxed_only_string, + normalize_final_answer, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2}", r"\boxed{x^2}"), + (r"No boxed", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x + 1}", "x + 1"), + ], + ) + def test_remove_boxed_valid(self, input_str, expected): + assert remove_boxed(input_str) == expected + + def test_remove_boxed_invalid(self): + with pytest.raises(AssertionError): + remove_boxed("not boxed") + + +class TestNormalizeFinalAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("42", "42"), + (" 42 ", "42"), + (r"\text{hello}", "hello"), + (r"\textbf{bold}", "bold"), + (r"x = 42", "42"), + (r"100 square", "100"), + (r"$50$ dollars", "50"), + (r"\boxed{42}", "42"), + (r"\frac12", r"\frac{1}{2}"), + (r"\sqrt3", r"\sqrt{3}"), + ("1,000", "1000"), + 
("<|im_end|>", ""), + ], + ) + def test_normalize_final_answer(self, input_str, expected): + assert normalize_final_answer(input_str) == expected + + +class TestIsCorrectMinerva: + @pytest.mark.parametrize( + "solution,gt,gt_need_extract,expected_correct", + [ + ("Answer: 42", "42", False, True), + ("Answer: 100", "42", False, False), + ("Answer: wrong", "42", False, False), + ("Answer: 42", r"\boxed{42}", True, True), + ], + ) + def test_is_correct_minerva(self, solution, gt, gt_need_extract, expected_correct): + correct, pred = is_correct_minerva(solution, gt, gt_need_extract=gt_need_extract) + assert correct == expected_correct + + +class TestIsCorrectStrictBox: + @pytest.mark.parametrize( + "pred,gt,expected_score,expected_pred", + [ + (r"blah blah \boxed{42}", "42", 1, "42"), + (r"\boxed{wrong}", "42", -1, "wrong"), + ("no box here", "42", -1, None), + ], + ) + def test_is_correct_strict_box(self, pred, gt, expected_score, expected_pred): + score, extracted = is_correct_strict_box(pred, gt) + assert score == expected_score + assert extracted == expected_pred + + +class TestComputeScore: + @pytest.mark.parametrize( + "solution,gt,strict_box,expected_score,expected_acc", + [ + ("Answer: 42", "42", False, 1.0, True), + ("Answer: wrong", "42", False, -1.0, False), + (r"\boxed{42}", "42", True, 1.0, True), + ("x" * 500 + " Answer: 42", "42", False, 1.0, True), + ], + ) + def test_compute_score(self, solution, gt, strict_box, expected_score, expected_acc): + result = compute_score(solution, gt, strict_box_verify=strict_box) + assert result["score"] == expected_score + assert result["acc"] == expected_acc diff --git a/tests/fast/rollout/rm_hub/test_math_utils.py b/tests/fast/rollout/rm_hub/test_math_utils.py new file mode 100644 index 000000000..2423ed4ac --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_math_utils.py @@ -0,0 +1,129 @@ +import pytest + +from miles.rollout.rm_hub.math_utils import ( + _normalize, + extract_answer, + grade_answer_mathd, + grade_answer_sympy, + grade_answer_verl, + last_boxed_only_string, + remove_boxed, +) + + +class TestLastBoxedOnlyString: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", r"\boxed{42}"), + (r"\boxed{x^2 + 1}", r"\boxed{x^2 + 1}"), + (r"So \boxed{\frac{1}{2}}", r"\boxed{\frac{1}{2}}"), + (r"No boxed here", None), + (r"Multiple \boxed{1} and \boxed{2}", r"\boxed{2}"), + (r"\boxed{nested {braces}}", r"\boxed{nested {braces}}"), + (r"\fbox{fbox content}", r"\fbox{fbox content}"), + ("", None), + ], + ) + def test_last_boxed_only_string(self, input_str, expected): + assert last_boxed_only_string(input_str) == expected + + +class TestRemoveBoxed: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"\boxed{42}", "42"), + (r"\boxed{x^2 + 1}", "x^2 + 1"), + (r"\boxed{\frac{1}{2}}", r"\frac{1}{2}"), + ("not boxed", None), + ], + ) + def test_remove_boxed(self, input_str, expected): + assert remove_boxed(input_str) == expected + + +class TestExtractAnswer: + @pytest.mark.parametrize( + "input_str,expected", + [ + (r"The answer is \boxed{42}", "42"), + (r"So \boxed{\frac{1}{2}}", r"\frac{1}{2}"), + (r"Multiple \boxed{1} then \boxed{final}", "final"), + (r"No boxed here", None), + ("", None), + ], + ) + def test_extract_answer(self, input_str, expected): + assert extract_answer(input_str) == expected + + +class TestNormalize: + @pytest.mark.parametrize( + "input_str,expected", + [ + ("1,000", "1000"), + (r"\text{hello}", "hello"), + (" 42 ", "42"), + (r"100%", "100"), + (r"\$50", "50"), + ("HELLO", "hello"), 
+ ("1,234,567", "1234567"), + (None, None), + ], + ) + def test_normalize(self, input_str, expected): + assert _normalize(input_str) == expected + + +class TestGradeAnswerMathd: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + (" 42 ", "42", True), + (r"\frac{1}{2}", r"\frac{1}{2}", True), + ("wrong", "42", False), + ("", "42", False), + ], + ) + def test_grade_answer_mathd(self, given, ground_truth, expected): + assert grade_answer_mathd(given, ground_truth) == expected + + +class TestGradeAnswerSympy: + @pytest.mark.parametrize( + "given,ground_truth,expected", + [ + ("42", "42", True), + ("x^2", "x^2", True), + ("1/2", "0.5", True), + (r"\frac{1}{2}", "0.5", True), + ("wrong", "42", False), + ("", "42", False), + ("(1,2)", "(1,2)", True), + ("(1,2,3)", "(1,2)", False), + ("42", None, False), + ], + ) + def test_grade_answer_sympy(self, given, ground_truth, expected): + assert grade_answer_sympy(given, ground_truth) == expected + + +class TestGradeAnswerVerl: + @pytest.mark.parametrize( + "solution,ground_truth,expected", + [ + (r"\boxed{42}", "42", True), + (r"The answer is \boxed{42}", "42", True), + (r"\boxed{1/2}", r"\frac{1}{2}", True), + (r"\boxed{wrong}", "42", False), + ("no boxed", "42", False), + (r"\boxed{42}", r"\boxed{42}", True), + ("", "42", False), + (r"\boxed{42}", "", False), + (r"\boxed{42}", None, False), + ], + ) + def test_grade_answer_verl(self, solution, ground_truth, expected): + assert grade_answer_verl(solution, ground_truth) == expected diff --git a/tests/fast/rollout/rm_hub/test_rm_hub.py b/tests/fast/rollout/rm_hub/test_rm_hub.py new file mode 100644 index 000000000..a3dadbdaf --- /dev/null +++ b/tests/fast/rollout/rm_hub/test_rm_hub.py @@ -0,0 +1,126 @@ +from unittest.mock import MagicMock + +import pytest + +from miles.rollout.rm_hub import async_rm, batched_async_rm +from miles.utils.async_utils import run +from miles.utils.types import Sample + + +@pytest.fixture +def mock_args(): + args = MagicMock() + args.custom_rm_path = None + args.rm_type = None + args.rm_url = None + return args + + +class TestAsyncRm: + @pytest.mark.parametrize( + "rm_type,response,label,expected", + [ + ("math", r"\boxed{42}", "42", 1), + ("math", r"\boxed{wrong}", "42", 0), + ("f1", "hello world", "hello world", 1.0), + ("dapo", "Answer: 42", "42", {"score": 1.0}), + ("deepscaler", r"\boxed{42}", "42", 1), + ("gpqa", "Answer: A", "A", 1.0), + ("boxed_f1", r"Final answer is \boxed{hello world}", "hello world", 1.0), + ], + ) + def test_rm_types(self, mock_args, rm_type, response, label, expected): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response=response, label=label) + reward = run(async_rm(mock_args, sample)) + if isinstance(expected, dict): + for k, v in expected.items(): + assert reward[k] == v + else: + assert reward == expected + + def test_f1_rm_partial(self, mock_args): + mock_args.rm_type = "f1" + sample = Sample(prompt="", response="hello", label="hello world") + reward = run(async_rm(mock_args, sample)) + assert 0 < reward < 1 + + def test_random_rm(self, mock_args): + mock_args.rm_type = "random" + sample = Sample(prompt="", response="anything", label="anything") + reward = run(async_rm(mock_args, sample)) + assert reward in [0, 1] + + def test_rm_type_from_metadata(self, mock_args): + mock_args.rm_type = None + sample = Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}) + reward = run(async_rm(mock_args, sample)) + assert reward == 1 + + @pytest.mark.parametrize( + 
"rm_type,match", + [ + ("unknown_type", "not implemented"), + ("", "not specified"), + ], + ) + def test_invalid_rm_type_raises(self, mock_args, rm_type, match): + mock_args.rm_type = rm_type + sample = Sample(prompt="", response="test", label="test") + with pytest.raises(NotImplementedError, match=match): + run(async_rm(mock_args, sample)) + + +class TestBatchedAsyncRm: + @pytest.mark.parametrize( + "rm_type,samples_data,expected", + [ + ( + "math", + [(r"\boxed{42}", "42"), (r"\boxed{100}", "100"), (r"\boxed{wrong}", "42")], + [1, 1, 0], + ), + ( + "f1", + [("hello world", "hello world"), ("different", "something else")], + [1.0, 0], + ), + ], + ) + def test_batched_rm(self, mock_args, rm_type, samples_data, expected): + mock_args.rm_type = rm_type + samples = [Sample(prompt="", response=r, label=label) for r, label in samples_data] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards == expected + + def test_inplace_set_reward_field(self, mock_args): + mock_args.rm_type = "math" + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42"), + Sample(prompt="", response=r"\boxed{100}", label="100"), + ] + result = run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + assert result is None + assert samples[0].reward == 1 + assert samples[1].reward == 1 + + def test_inplace_raises_on_existing_reward(self, mock_args): + mock_args.rm_type = "math" + samples = [Sample(prompt="", response=r"\boxed{42}", label="42", reward=0.5)] + with pytest.raises(AssertionError, match="Overriding"): + run(batched_async_rm(mock_args, samples, inplace_set_reward_field=True)) + + def test_empty_samples(self, mock_args): + mock_args.rm_type = "math" + rewards = run(batched_async_rm(mock_args, [])) + assert rewards == [] + + def test_mixed_rm_types_via_metadata(self, mock_args): + mock_args.rm_type = None + samples = [ + Sample(prompt="", response=r"\boxed{42}", label="42", metadata={"rm_type": "math"}), + Sample(prompt="", response="hello", label="hello", metadata={"rm_type": "f1"}), + ] + rewards = run(batched_async_rm(mock_args, samples)) + assert rewards[0] == 1 + assert rewards[1] == 1.0 diff --git a/tests/fast/router/__init__.py b/tests/fast/router/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/router/test_router.py b/tests/fast/router/test_router.py new file mode 100644 index 000000000..8bad5874c --- /dev/null +++ b/tests/fast/router/test_router.py @@ -0,0 +1,207 @@ +import asyncio +from argparse import Namespace + +import pytest +import requests + +from miles.router.router import MilesRouter +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import MockSGLangServer, default_process_fn +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + +def make_router_args(router_port: int, **overrides) -> Namespace: + defaults = dict( + sglang_router_ip="127.0.0.1", + sglang_router_port=router_port, + rollout_health_check_interval=1.0, + miles_router_health_check_failure_threshold=3, + miles_router_max_connections=100, + miles_router_timeout=None, + miles_router_middleware_paths=[], + hf_checkpoint="Qwen/Qwen3-0.6B", + cross_turn_token_out=False, + inherit_last_assistant=False, + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def create_mock_worker(start_port: int = 30000) -> MockSGLangServer: + port = find_available_port(start_port) + return MockSGLangServer( + model_name="Qwen/Qwen3-0.6B", + process_fn=default_process_fn, + 
host="127.0.0.1", + port=port, + latency=0.0, + ) + + +class RouterEnv: + def __init__(self, router: MilesRouter, server: UvicornThreadServer): + self.router = router + self.server = server + + @property + def url(self) -> str: + return self.server.url + + +@pytest.fixture +def router_env(): + args = make_router_args(find_available_port(20000)) + router = MilesRouter(args, verbose=False) + server = UvicornThreadServer(router.app, host=args.sglang_router_ip, port=args.sglang_router_port) + server.start() + yield RouterEnv(router, server) + server.stop() + + +@pytest.fixture +def mock_worker(): + server = create_mock_worker() + server.start() + yield server + server.stop() + + +@pytest.fixture +def mock_worker_factory(): + servers = [] + + def _create(): + start_port = 30000 + len(servers) * 100 + server = create_mock_worker(start_port) + server.start() + servers.append(server) + return server + + yield _create + for s in servers: + s.stop() + + +@pytest.fixture +def router_factory(): + def _create(**overrides) -> MilesRouter: + args = make_router_args(find_available_port(20000), **overrides) + return MilesRouter(args, verbose=False) + + return _create + + +class TestWorkerManagement: + def test_add_worker_via_query_param(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30001" + r = requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + assert router_env.router.worker_request_counts[worker_url] == 0 + + def test_add_worker_via_body(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30002" + r = requests.post(f"{router_env.url}/add_worker", json={"url": worker_url}, timeout=5.0) + r.raise_for_status() + + assert r.json()["status"] == "success" + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_duplicate(self, router_env: RouterEnv): + worker_url = "http://127.0.0.1:30003" + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + requests.post(f"{router_env.url}/add_worker", params={"url": worker_url}, timeout=5.0).raise_for_status() + + assert len(router_env.router.worker_request_counts) == 1 + assert worker_url in router_env.router.worker_request_counts + + def test_add_worker_missing_url(self, router_env: RouterEnv): + r = requests.post(f"{router_env.url}/add_worker", json={}, timeout=5.0) + assert r.status_code == 400 + assert "error" in r.json() + + def test_list_workers(self, router_env: RouterEnv): + worker_urls = ["http://127.0.0.1:30001", "http://127.0.0.1:30002"] + for url in worker_urls: + requests.post(f"{router_env.url}/add_worker", params={"url": url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/list_workers", timeout=5.0) + r.raise_for_status() + assert set(r.json()["urls"]) == set(worker_urls) + + +class TestLoadBalancing: + def test_use_url_selects_min_load(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 2, "http://w3:8000": 8} + + selected = router._use_url() + assert selected == "http://w2:8000" + assert router.worker_request_counts["http://w2:8000"] == 3 + + def test_use_url_excludes_dead_workers(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 5, "http://w2:8000": 1, "http://w3:8000": 3} + router.dead_workers = {"http://w2:8000"} + + selected = router._use_url() + 
assert selected == "http://w3:8000" + assert router.worker_request_counts["http://w3:8000"] == 4 + + def test_use_url_raises_when_all_dead(self, router_factory): + router = router_factory() + router.worker_request_counts = {"http://w1:8000": 0} + router.dead_workers = {"http://w1:8000"} + + with pytest.raises(RuntimeError, match="No healthy workers"): + router._use_url() + + +# TODO: extract main body inside `_health_check_loop`, then can test that function +class TestHealthCheck: + def test_check_worker_health_success(self, router_factory, mock_worker: MockSGLangServer): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health(mock_worker.url)) + assert url == mock_worker.url + assert healthy is True + + def test_check_worker_health_failure(self, router_factory): + router = router_factory() + url, healthy = asyncio.run(router._check_worker_health("http://127.0.0.1:59999")) + assert url == "http://127.0.0.1:59999" + assert healthy is False + + +class TestProxyIntegration: + def test_proxy_forwards_request(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0).raise_for_status() + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + r = requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0) + r.raise_for_status() + + assert "text" in r.json() + assert len(mock_worker.request_log) == 1 + assert mock_worker.request_log[0] == payload + + def test_proxy_multi_worker(self, router_env: RouterEnv, mock_worker_factory): + worker1, worker2 = mock_worker_factory(), mock_worker_factory() + requests.post(f"{router_env.url}/add_worker", params={"url": worker1.url}, timeout=5.0) + requests.post(f"{router_env.url}/add_worker", params={"url": worker2.url}, timeout=5.0) + + payload = {"input_ids": [1, 2, 3], "return_logprob": True} + for _ in range(4): + requests.post(f"{router_env.url}/generate", json=payload, timeout=10.0).raise_for_status() + + all_requests = worker1.request_log + worker2.request_log + assert len(all_requests) == 4 + assert all(req == payload for req in all_requests) + + def test_proxy_health_endpoint(self, router_env: RouterEnv, mock_worker: MockSGLangServer): + requests.post(f"{router_env.url}/add_worker", params={"url": mock_worker.url}, timeout=5.0) + + r = requests.get(f"{router_env.url}/health", timeout=5.0) + r.raise_for_status() + assert r.json()["status"] == "ok" diff --git a/tests/fast/router/test_seq_trajectory.py b/tests/fast/router/test_seq_trajectory.py new file mode 100644 index 000000000..0705996a8 --- /dev/null +++ b/tests/fast/router/test_seq_trajectory.py @@ -0,0 +1,366 @@ +from types import SimpleNamespace + +import pytest +from transformers import AutoTokenizer + +from miles.rollout.generate_utils.tokenize_utils import tokenize_messages +from miles.router.session import seq_trajectory +from miles.utils.chat_message_utils import get_think_token_start + +MODEL_NAME = "Qwen/Qwen3-4B" +TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + + +def _messages(items: list[tuple[str, str]]) -> list[dict[str, str]]: + return [{"role": role, "content": content} for role, content in items] + + +def _token_info_from_ids(token_ids: list[int]) -> seq_trajectory.TokenInfo: + return seq_trajectory.TokenInfo( + tokens=TOKENIZER.convert_ids_to_tokens(token_ids), + token_ids=token_ids, + log_probs=[0.0] * len(token_ids), + loss_mask=[1] * len(token_ids), + ) + + +def _turn(messages: list[dict[str, str]], 
prompt_ids: list[int], response_ids: list[int]) -> seq_trajectory.Turn: + payload = { + "messages": messages, + "prompt_tokens": _token_info_from_ids(prompt_ids), + "response_tokens": _token_info_from_ids(response_ids), + } + if hasattr(seq_trajectory.Turn, "model_construct"): + return seq_trajectory.Turn.model_construct(**payload) + return seq_trajectory.Turn.construct(**payload) + + +def _turn_from_messages(messages: list[dict[str, str]]) -> seq_trajectory.Turn: + prompt_token_ids = TOKENIZER.apply_chat_template( + messages[:-1], + tokenize=True, + add_generation_prompt=True, + ) + response_token_ids = TOKENIZER.encode(messages[-1]["content"], add_special_tokens=False) + return _turn(messages, prompt_token_ids, response_token_ids) + + +def _assert_prompt_token_info(token_info: seq_trajectory.TokenInfo, expected_token_ids: list[int]) -> None: + assert token_info.token_ids == expected_token_ids + assert token_info.tokens == TOKENIZER.convert_ids_to_tokens(expected_token_ids) + assert token_info.log_probs == [0.0] * len(expected_token_ids) + assert token_info.loss_mask == [0] * len(expected_token_ids) + + +def _make_manager(*, cross_turn_token_out: bool, inherit_last_assistant: bool) -> seq_trajectory.SeqTrajectoryManager: + args = SimpleNamespace( + cross_turn_token_out=cross_turn_token_out, + inherit_last_assistant=inherit_last_assistant, + ) + return seq_trajectory.SeqTrajectoryManager(args, TOKENIZER) + + +def test_turn_match_prefix_messages_returns_remaining(): + messages = _messages([("user", "hi"), ("assistant", "ok"), ("user", "next"), ("assistant", "done")]) + turn = _turn(messages, [], []) + + remaining = turn.match_prefix_messages_and_return_remaining(messages[:2]) + + assert remaining == messages[2:] + + +def test_turn_match_prefix_messages_exact_match_returns_empty(): + messages = _messages([("user", "hi"), ("assistant", "ok")]) + turn = _turn(messages, [], []) + + remaining = turn.match_prefix_messages_and_return_remaining(messages) + + assert remaining == [] + + +def test_turn_match_prefix_messages_mismatch_returns_none(): + messages = _messages([("user", "hi"), ("assistant", "ok")]) + turn = _turn(messages, [], []) + + assert turn.match_prefix_messages_and_return_remaining([{"role": "user", "content": "nope"}]) is None + assert ( + turn.match_prefix_messages_and_return_remaining(messages + [{"role": "assistant", "content": "extra"}]) is None + ) + + +def test_calc_prompt_tokens_info_multi_turn_cross_turn_disabled_uses_last_turn(): + trajectory = seq_trajectory.SeqTrajectory() + turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + turn2_messages = _messages( + [ + ("system", "sys"), + ("user", "u1"), + ("assistant", "a1"), + ("user", "u2"), + ("assistant", "a2"), + ] + ) + + trajectory.insert_new_turn(_turn_from_messages(turn1_messages)) + trajectory.insert_new_turn(_turn_from_messages(turn2_messages)) + + token_info = trajectory.calc_prompt_tokens_info( + turn2_messages, + TOKENIZER, + cross_turn_token_out=False, + inherit_last_assistant=True, + ) + expected_token_ids = TOKENIZER.apply_chat_template(turn2_messages, tokenize=True, add_generation_prompt=True) + _assert_prompt_token_info(token_info, expected_token_ids) + + +def test_calc_prompt_tokens_info_multi_turn_cross_turn_uses_prefix_suffix(): + trajectory = seq_trajectory.SeqTrajectory() + turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + turn2_messages = _messages( + [ + ("system", "sys"), + ("user", "u1"), + ("assistant", "a1"), + ("user", "u2"), + 
("assistant", "a2"), + ] + ) + turn1 = _turn_from_messages(turn1_messages) + trajectory.insert_new_turn(turn1) + trajectory.insert_new_turn(_turn_from_messages(turn2_messages)) + + input_messages = _messages([("system", "sys")]) + remain_messages = _messages([("user", "u1"), ("assistant", "a1")]) + + token_info = trajectory.calc_prompt_tokens_info( + input_messages, + TOKENIZER, + cross_turn_token_out=True, + inherit_last_assistant=False, + ) + expected_new_token_ids = tokenize_messages(remain_messages, TOKENIZER, add_generation_prompt=True) + expected_token_ids = turn1.prompt_tokens.token_ids + turn1.response_tokens.token_ids + expected_new_token_ids + _assert_prompt_token_info(token_info, expected_token_ids) + + +def test_calc_prompt_tokens_info_multi_turn_cross_turn_matches_two_turns(): + trajectory = seq_trajectory.SeqTrajectory() + turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + turn2_messages = _messages([("user", "u1"), ("assistant", "a1"), ("user", "u2"), ("assistant", "a2")]) + turn3_messages = _messages([("user", "u3"), ("assistant", "a3")]) + turn2 = _turn_from_messages(turn2_messages) + + trajectory.insert_new_turn(_turn_from_messages(turn1_messages)) + trajectory.insert_new_turn(turn2) + trajectory.insert_new_turn(_turn_from_messages(turn3_messages)) + + input_messages = _messages([("system", "sys")]) + remain_messages = _messages([("user", "u2"), ("assistant", "a2")]) + + token_info = trajectory.calc_prompt_tokens_info( + input_messages, + TOKENIZER, + cross_turn_token_out=True, + inherit_last_assistant=False, + ) + expected_new_token_ids = tokenize_messages(remain_messages, TOKENIZER, add_generation_prompt=True) + expected_token_ids = turn2.prompt_tokens.token_ids + turn2.response_tokens.token_ids + expected_new_token_ids + _assert_prompt_token_info(token_info, expected_token_ids) + + +def test_calc_prompt_tokens_info_multi_turn_cross_turn_empty_remaining_messages(): + trajectory = seq_trajectory.SeqTrajectory() + turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + turn2_messages = _messages( + [ + ("system", "sys"), + ("user", "u1"), + ("assistant", "a1"), + ("user", "u2"), + ("assistant", "a2"), + ] + ) + turn1 = _turn_from_messages(turn1_messages) + + trajectory.insert_new_turn(turn1) + trajectory.insert_new_turn(_turn_from_messages(turn2_messages)) + + token_info = trajectory.calc_prompt_tokens_info( + turn1_messages, + TOKENIZER, + cross_turn_token_out=True, + inherit_last_assistant=False, + ) + expected_token_ids = turn1.prompt_tokens.token_ids + turn1.response_tokens.token_ids + _assert_prompt_token_info(token_info, expected_token_ids) + + +def test_tokenize_messages_trims_complete_think_content(): + messages_with_think = _messages([("assistant", "thoughtanswer")]) + messages_plain = _messages([("assistant", "answer")]) + + tokens_with_think = tokenize_messages(messages_with_think, TOKENIZER, add_generation_prompt=True) + tokens_plain = tokenize_messages(messages_plain, TOKENIZER, add_generation_prompt=True) + + think_start_id = get_think_token_start("qwen3")[1] + + assert tokens_with_think == tokens_plain + assert think_start_id not in tokens_with_think + + +def test_tokenize_messages_does_not_trim_incomplete_think_content(): + messages_incomplete_think = _messages([("assistant", "thought answer")]) + messages_plain = _messages([("assistant", "answer")]) + + tokens_incomplete = tokenize_messages(messages_incomplete_think, TOKENIZER, add_generation_prompt=True) + tokens_plain = 
tokenize_messages(messages_plain, TOKENIZER, add_generation_prompt=True) + + think_start_id = get_think_token_start("qwen3")[1] + + assert tokens_incomplete != tokens_plain + assert think_start_id in tokens_incomplete + + +def test_manager_calc_prompt_tokens_missing_session_returns_none(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + messages = _messages([("system", "sys"), ("user", "hi")]) + + assert manager.calc_prompt_tokens("missing", messages) is None + + +def test_manager_get_session_by_id_empty_returns_empty_token_info(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + session_id = manager.create_session() + + token_info = manager.get_session_by_id(session_id) + assert token_info is not None + assert token_info.tokens == [] + assert token_info.token_ids == [] + assert token_info.log_probs == [] + assert token_info.loss_mask == [] + + +def test_manager_calc_prompt_tokens_no_turns_retokens_messages(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + session_id = manager.create_session() + messages = _messages([("system", "sys"), ("user", "u1")]) + + token_info = manager.calc_prompt_tokens(session_id, messages) + + expected_token_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + _assert_prompt_token_info(token_info, expected_token_ids) + + +def test_manager_calc_prompt_tokens_inherit_last_assistant_raises(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=True) + session_id = manager.create_session() + turn_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + manager.add_record(session_id, _turn_from_messages(turn_messages)) + + with pytest.raises(NotImplementedError): + manager.calc_prompt_tokens(session_id, turn_messages) + + +def test_manager_calc_prompt_tokens_cross_turn_single_turn_uses_tokenize_messages(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + session_id = manager.create_session() + turn_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + manager.add_record(session_id, _turn_from_messages(turn_messages)) + + messages = _messages([("system", "sys"), ("user", "next")]) + token_info = manager.calc_prompt_tokens(session_id, messages) + + expected_token_ids = tokenize_messages(messages, TOKENIZER, add_generation_prompt=True) + _assert_prompt_token_info(token_info, expected_token_ids) + + +def test_manager_calc_prompt_tokens_cross_turn_multi_turn_prefix_success(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + session_id = manager.create_session() + turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + turn2_messages = _messages([("user", "u2"), ("assistant", "a2")]) + turn1 = _turn_from_messages(turn1_messages) + manager.add_record(session_id, turn1) + manager.add_record(session_id, _turn_from_messages(turn2_messages)) + + input_messages = _messages([("system", "sys")]) + token_info = manager.calc_prompt_tokens(session_id, input_messages) + + remain_messages = _messages([("user", "u1"), ("assistant", "a1")]) + expected_new_token_ids = tokenize_messages(remain_messages, TOKENIZER, add_generation_prompt=True) + expected_token_ids = turn1.prompt_tokens.token_ids + turn1.response_tokens.token_ids + expected_new_token_ids + _assert_prompt_token_info(token_info, expected_token_ids) + + +def 
test_manager_calc_prompt_tokens_cross_turn_multi_turn_prefix_mismatch_raises(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + session_id = manager.create_session() + turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + manager.add_record(session_id, _turn_from_messages(turn1_messages)) + manager.add_record(session_id, _turn_from_messages(_messages([("user", "u2"), ("assistant", "a2")]))) + + with pytest.raises(ValueError): + manager.calc_prompt_tokens(session_id, _messages([("system", "nope")])) + + +def test_manager_calc_prompt_tokens_cross_turn_disabled_retokens_messages(): + manager = _make_manager(cross_turn_token_out=False, inherit_last_assistant=True) + session_id = manager.create_session() + manager.add_record( + session_id, _turn_from_messages(_messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")])) + ) + + messages = _messages([("system", "sys"), ("user", "new")]) + token_info = manager.calc_prompt_tokens(session_id, messages) + + expected_token_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + _assert_prompt_token_info(token_info, expected_token_ids) + + +def test_manager_get_session_by_id_after_add_record_returns_combined_tokens(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + session_id = manager.create_session() + messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + turn = _turn_from_messages(messages) + manager.add_record(session_id, turn) + + token_info = manager.get_session_by_id(session_id) + + expected_token_ids = turn.prompt_tokens.token_ids + turn.response_tokens.token_ids + assert token_info.token_ids == expected_token_ids + assert token_info.tokens == TOKENIZER.convert_ids_to_tokens(expected_token_ids) + assert token_info.log_probs == turn.prompt_tokens.log_probs + turn.response_tokens.log_probs + assert token_info.loss_mask == turn.prompt_tokens.loss_mask + turn.response_tokens.loss_mask + + +def test_manager_delete_session_by_id(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + session_id = manager.create_session() + + assert manager.delete_session_by_id(session_id) is True + assert manager.delete_session_by_id(session_id) is False + + +def test_manager_add_record_missing_session_raises(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + turn = _turn_from_messages(messages) + + with pytest.raises(ValueError): + manager.add_record("missing", turn) + + +def test_manager_calc_prompt_tokens_cross_turn_multi_turn_empty_remaining_messages(): + manager = _make_manager(cross_turn_token_out=True, inherit_last_assistant=False) + session_id = manager.create_session() + turn1_messages = _messages([("system", "sys"), ("user", "u1"), ("assistant", "a1")]) + turn2_messages = _messages([("user", "u2"), ("assistant", "a2")]) + turn1 = _turn_from_messages(turn1_messages) + manager.add_record(session_id, turn1) + manager.add_record(session_id, _turn_from_messages(turn2_messages)) + + token_info = manager.calc_prompt_tokens(session_id, turn1_messages) + + expected_token_ids = turn1.prompt_tokens.token_ids + turn1.response_tokens.token_ids + _assert_prompt_token_info(token_info, expected_token_ids) diff --git a/tests/fast/router/test_sessions.py b/tests/fast/router/test_sessions.py new file mode 100644 index 000000000..3ab179fde --- /dev/null 
+++ b/tests/fast/router/test_sessions.py @@ -0,0 +1,314 @@ +from types import SimpleNamespace + +import pytest +import requests +from transformers import AutoTokenizer + +from miles.rollout.generate_utils.tokenize_utils import tokenize_messages +from miles.router.router import MilesRouter +from miles.utils.http_utils import find_available_port +from miles.utils.test_utils.mock_sglang_server import ProcessResult, with_mock_server +from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" +TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + + +@pytest.fixture(scope="module") +def router_env(): + def process_fn(_prompt: str) -> ProcessResult: + return ProcessResult(text="ok", finish_reason="stop") + + with with_mock_server(model_name=MODEL_NAME, process_fn=process_fn) as backend: + args = SimpleNamespace( + miles_router_max_connections=10, + miles_router_timeout=30, + miles_router_middleware_paths=[], + rollout_health_check_interval=60, + miles_router_health_check_failure_threshold=3, + hf_checkpoint=MODEL_NAME, + cross_turn_token_out=False, + inherit_last_assistant=False, + ) + router = MilesRouter(args) + + port = find_available_port(31000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=port) + server.start() + + url = f"http://127.0.0.1:{port}" + requests.post(f"{url}/add_worker", json={"url": backend.url}) + + try: + yield {"url": url, "backend": backend} + finally: + server.stop() + + +def _create_session(url: str) -> str: + response = requests.post(f"{url}/sessions") + assert response.status_code == 200 + return response.json()["session_id"] + + +def _extract_response_tokens(response_body: dict) -> tuple[list[int], list[float], list[str]]: + logprobs_content = response_body["choices"][0]["logprobs"]["content"] + token_ids = [item.get("token_id", TOKENIZER.convert_tokens_to_ids(item["token"])) for item in logprobs_content] + logprobs = [item["logprob"] for item in logprobs_content] + tokens = [item["token"] for item in logprobs_content] + return token_ids, logprobs, tokens + + +def test_create_session_and_get_empty_records(router_env): + url = router_env["url"] + session_id = _create_session(url) + + response = requests.get(f"{url}/sessions/{session_id}") + assert response.status_code == 200 + + data = response.json() + assert data["session_id"] == session_id + assert data["records"] == { + "tokens": [], + "token_ids": [], + "log_probs": [], + "loss_mask": [], + } + + +def test_get_session_not_found(router_env): + url = router_env["url"] + response = requests.get(f"{url}/sessions/nonexistent") + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + +def test_delete_session(router_env): + url = router_env["url"] + session_id = _create_session(url) + + delete_resp = requests.delete(f"{url}/sessions/{session_id}") + assert delete_resp.status_code == 204 + assert delete_resp.text == "" + + missing_resp = requests.delete(f"{url}/sessions/{session_id}") + assert missing_resp.status_code == 404 + assert missing_resp.json()["error"] == "session not found" + + +def test_proxy_session_not_found(router_env): + url = router_env["url"] + response = requests.post( + f"{url}/sessions/nonexistent/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}]}, + ) + assert response.status_code == 404 + assert response.json()["error"] == "session not found" + + +def test_proxy_inserts_input_ids_and_records_tokens(router_env): + url = 
router_env["url"] + backend = router_env["backend"] + backend.reset_stats() + + session_id = _create_session(url) + messages = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + + response = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": messages}, + ) + assert response.status_code == 200 + + response_body = response.json() + + expected_prompt_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + backend_payload = backend.request_log[-1] + assert backend_payload["input_ids"] == expected_prompt_ids + + response_token_ids, response_logprobs, response_tokens = _extract_response_tokens(response_body) + + get_resp = requests.get(f"{url}/sessions/{session_id}") + assert get_resp.status_code == 200 + + records = get_resp.json()["records"] + expected_token_ids = expected_prompt_ids + response_token_ids + assert records["token_ids"] == expected_token_ids + assert records["tokens"] == TOKENIZER.convert_ids_to_tokens(expected_prompt_ids) + response_tokens + assert records["log_probs"] == [0.0] * len(expected_prompt_ids) + response_logprobs + assert records["loss_mask"] == [0] * len(expected_prompt_ids) + [1] * len(response_token_ids) + + +def test_proxy_preserves_input_ids_when_provided(router_env): + url = router_env["url"] + backend = router_env["backend"] + backend.reset_stats() + + session_id = _create_session(url) + messages = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + base_prompt_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + custom_input_ids = base_prompt_ids + [base_prompt_ids[-1]] + + response = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": messages, "input_ids": custom_input_ids}, + ) + assert response.status_code == 200 + + backend_payload = backend.request_log[-1] + assert backend_payload["input_ids"] == custom_input_ids + + response_body = response.json() + response_token_ids, response_logprobs, response_tokens = _extract_response_tokens(response_body) + + get_resp = requests.get(f"{url}/sessions/{session_id}") + records = get_resp.json()["records"] + assert records["token_ids"] == response_token_ids + assert records["tokens"] == response_tokens + assert records["log_probs"] == response_logprobs + assert records["loss_mask"] == [1] * len(response_token_ids) + + +def test_proxy_multi_turn_second_call_uses_only_new_messages(router_env): + url = router_env["url"] + backend = router_env["backend"] + backend.reset_stats() + + session_id = _create_session(url) + messages_turn1 = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + response1 = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": messages_turn1}, + ) + assert response1.status_code == 200 + + messages_turn2 = [{"role": "user", "content": "next"}] + response2 = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": messages_turn2}, + ) + assert response2.status_code == 200 + + expected_prompt_ids = tokenize_messages(messages_turn2, TOKENIZER, add_generation_prompt=True) + backend_payload = backend.request_log[-1] + assert backend_payload["input_ids"] == expected_prompt_ids + + response2_body = response2.json() + response2_token_ids, response2_logprobs, response2_tokens = _extract_response_tokens(response2_body) + + get_resp = requests.get(f"{url}/sessions/{session_id}") + records = 
get_resp.json()["records"] + expected_token_ids = expected_prompt_ids + response2_token_ids + assert records["token_ids"] == expected_token_ids + assert records["tokens"] == TOKENIZER.convert_ids_to_tokens(expected_prompt_ids) + response2_tokens + assert records["log_probs"] == [0.0] * len(expected_prompt_ids) + response2_logprobs + assert records["loss_mask"] == [0] * len(expected_prompt_ids) + [1] * len(response2_token_ids) + + +def test_proxy_third_call_reuses_first_turn_prefix(router_env): + url = router_env["url"] + backend = router_env["backend"] + backend.reset_stats() + + session_id = _create_session(url) + messages_turn1 = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + response1 = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": messages_turn1}, + ) + assert response1.status_code == 200 + + response1_body = response1.json() + response1_token_ids, _, _ = _extract_response_tokens(response1_body) + prompt1_ids = TOKENIZER.apply_chat_template(messages_turn1, tokenize=True, add_generation_prompt=True) + + response2 = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": [{"role": "user", "content": "next"}]}, + ) + assert response2.status_code == 200 + + assistant_message = response1_body["choices"][0]["message"] + messages_turn3 = [{"role": "system", "content": "sys"}] + response3 = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": messages_turn3}, + ) + assert response3.status_code == 200 + + remain_messages = [messages_turn1[1], assistant_message] + expected_prompt_ids = ( + prompt1_ids + + response1_token_ids + + tokenize_messages( + remain_messages, + TOKENIZER, + add_generation_prompt=True, + ) + ) + backend_payload = backend.request_log[-1] + assert backend_payload["input_ids"] == expected_prompt_ids + + response3_body = response3.json() + response3_token_ids, response3_logprobs, response3_tokens = _extract_response_tokens(response3_body) + + get_resp = requests.get(f"{url}/sessions/{session_id}") + records = get_resp.json()["records"] + expected_token_ids = expected_prompt_ids + response3_token_ids + assert records["token_ids"] == expected_token_ids + assert records["tokens"] == TOKENIZER.convert_ids_to_tokens(expected_prompt_ids) + response3_tokens + assert records["log_probs"] == [0.0] * len(expected_prompt_ids) + response3_logprobs + assert records["loss_mask"] == [0] * len(expected_prompt_ids) + [1] * len(response3_token_ids) + + +def test_proxy_respects_token_id_in_logprobs(router_env): + url = router_env["url"] + backend = router_env["backend"] + backend.reset_stats() + + original_compute = backend._compute_chat_completions_response + + def _custom_compute(payload: dict) -> dict: + response = original_compute(payload) + for idx, item in enumerate(response["choices"][0]["logprobs"]["content"]): + item["token_id"] = 900000 + idx + return response + + backend._compute_chat_completions_response = _custom_compute + try: + session_id = _create_session(url) + messages = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + response = requests.post( + f"{url}/sessions/{session_id}/v1/chat/completions", + json={"messages": messages}, + ) + assert response.status_code == 200 + + response_body = response.json() + custom_ids, response_logprobs, response_tokens = _extract_response_tokens(response_body) + prompt_ids = TOKENIZER.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + + 
get_resp = requests.get(f"{url}/sessions/{session_id}") + records = get_resp.json()["records"] + expected_token_ids = prompt_ids + custom_ids + assert records["token_ids"] == expected_token_ids + assert records["tokens"] == TOKENIZER.convert_ids_to_tokens(prompt_ids) + response_tokens + assert records["log_probs"] == [0.0] * len(prompt_ids) + response_logprobs + assert records["loss_mask"] == [0] * len(prompt_ids) + [1] * len(custom_ids) + finally: + backend._compute_chat_completions_response = original_compute diff --git a/tests/fast/utils/__init__.py b/tests/fast/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/utils/test_arguments.py b/tests/fast/utils/test_arguments.py new file mode 100644 index 000000000..9bd1a620d --- /dev/null +++ b/tests/fast/utils/test_arguments.py @@ -0,0 +1,58 @@ +import argparse +import sys +from unittest.mock import patch + +import pytest + +from miles.utils.arguments import get_miles_extra_args_provider +from miles.utils.misc import function_registry + +PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"] +REQUIRED_ARGS = ["--rollout-batch-size", "64"] + + +def make_class_with_add_arguments(): + class MyFn: + @classmethod + def add_arguments(cls, parser): + parser.add_argument("--my-custom-arg", type=int, default=42) + + return MyFn + + +def make_function_with_add_arguments(): + def my_fn(): + pass + + my_fn.add_arguments = lambda parser: parser.add_argument("--my-custom-arg", type=int, default=42) + return my_fn + + +def make_function_without_add_arguments(): + def my_fn(): + pass + + return my_fn + + +@pytest.mark.parametrize("path_arg", PATH_ARGS) +class TestAddArgumentsSupport: + + @pytest.mark.parametrize("fn_factory", [make_class_with_add_arguments, make_function_with_add_arguments]) + def test_add_arguments_is_called_and_arg_is_parsed(self, path_arg, fn_factory): + fn = fn_factory() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn", "--my-custom-arg", "100"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) + args, _ = parser.parse_known_args() + assert args.my_custom_arg == 100 + + def test_skips_function_without_add_arguments(self, path_arg): + fn = make_function_without_add_arguments() + with function_registry.temporary("test:fn", fn), patch.object( + sys, "argv", ["test", path_arg, "test:fn"] + REQUIRED_ARGS + ): + parser = argparse.ArgumentParser() + get_miles_extra_args_provider()(parser) diff --git a/tests/utils/test_mask_utils.py b/tests/fast/utils/test_mask_utils.py similarity index 100% rename from tests/utils/test_mask_utils.py rename to tests/fast/utils/test_mask_utils.py diff --git a/tests/fast/utils/test_misc.py b/tests/fast/utils/test_misc.py new file mode 100644 index 000000000..810c2b67c --- /dev/null +++ b/tests/fast/utils/test_misc.py @@ -0,0 +1,59 @@ +import os + +import pytest + +from miles.utils.misc import FunctionRegistry, function_registry, load_function + + +def _fn_a(): + return "a" + + +def _fn_b(): + return "b" + + +class TestFunctionRegistry: + def test_register_and_get(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + + def test_register_duplicate_raises(self): + registry = FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + with pytest.raises(AssertionError): + with registry.temporary("my_fn", _fn_b): + pass + + def test_unregister(self): + registry = 
FunctionRegistry() + with registry.temporary("my_fn", _fn_a): + assert registry.get("my_fn") is _fn_a + assert registry.get("my_fn") is None + + def test_temporary_cleanup_on_exception(self): + registry = FunctionRegistry() + with pytest.raises(RuntimeError): + with registry.temporary("temp_fn", _fn_a): + raise RuntimeError("test") + assert registry.get("temp_fn") is None + + +class TestLoadFunction: + def test_load_from_module(self): + import os.path + + assert load_function("os.path.join") is os.path.join + + def test_load_none_returns_none(self): + assert load_function(None) is None + + def test_load_from_registry(self): + with function_registry.temporary("test:my_fn", _fn_a): + assert load_function("test:my_fn") is _fn_a + + def test_registry_takes_precedence(self): + with function_registry.temporary("os.path.join", _fn_b): + assert load_function("os.path.join") is _fn_b + assert load_function("os.path.join") is os.path.join diff --git a/tests/fast/utils/test_utils/__init__.py b/tests/fast/utils/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fast/utils/test_utils/test_mock_sglang_server.py b/tests/fast/utils/test_utils/test_mock_sglang_server.py new file mode 100644 index 000000000..6633678da --- /dev/null +++ b/tests/fast/utils/test_utils/test_mock_sglang_server.py @@ -0,0 +1,409 @@ +import asyncio +import concurrent.futures +import time + +import pytest +import requests + +from miles.utils.test_utils.mock_sglang_server import ( + Counter, + ProcessResult, + ProcessResultMetaInfo, + default_process_fn, + with_mock_server, +) +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub + + +def expected_logprobs(tokenizer, text: str) -> list[dict]: + output_ids = tokenizer.encode(text, add_special_tokens=False) + return [{"token": tokenizer.convert_ids_to_tokens(tid), "logprob": -i / 128} for i, tid in enumerate(output_ids)] + + +@pytest.fixture(scope="module") +def mock_server(): + with with_mock_server() as server: + yield server + + +class TestProcessResultMetaInfo: + def test_to_dict_empty(self): + assert ProcessResultMetaInfo().to_dict() == {} + + def test_to_dict_single_field(self): + assert ProcessResultMetaInfo(weight_version="v1").to_dict() == {"weight_version": "v1"} + + def test_to_dict_partial_fields(self): + assert ProcessResultMetaInfo(weight_version="v1", spec_accept_token_num=10).to_dict() == { + "weight_version": "v1", + "spec_accept_token_num": 10, + } + + def test_to_dict_all_fields(self): + assert ProcessResultMetaInfo( + weight_version="v1", + routed_experts="abc", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ).to_dict() == { + "weight_version": "v1", + "routed_experts": "abc", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + } + + +class TestDefaultProcessFn: + def test_math_question(self): + assert default_process_fn("What is 1+5?") == ProcessResult(text="\\boxed{6}", finish_reason="stop") + assert default_process_fn("What is 1+10?") == ProcessResult(text="\\boxed{11}", finish_reason="stop") + + def test_unknown_question(self): + assert default_process_fn("Hello") == ProcessResult(text="I don't understand.", finish_reason="stop") + + +class TestCounter: + def test_tracks_max(self): + counter = Counter() + assert counter.max_value == 0 + + with counter.track(): + assert counter.max_value == 1 + with counter.track(): + assert counter.max_value == 2 + + counter.reset() + assert counter.max_value == 0 + + def test_concurrent_tasks(self): + 
counter = Counter() + + async def task(): + with counter.track(): + await asyncio.sleep(0.1) + + async def run_all(): + await asyncio.gather(task(), task(), task()) + + asyncio.run(run_all()) + assert counter.max_value == 3 + + +class TestMockServerBasic: + def test_start_stop(self, mock_server): + assert mock_server.port > 0 + assert f"http://{mock_server.host}:{mock_server.port}" == mock_server.url + + def test_request_log_and_reset_stats(self, mock_server): + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + + payload = {"input_ids": [1, 2, 3], "sampling_params": {"temperature": 0.5}, "return_logprob": True} + requests.post(f"{mock_server.url}/generate", json=payload, timeout=5.0) + assert len(mock_server.request_log) == 1 + assert mock_server.request_log[0] == payload + + mock_server.reset_stats() + assert len(mock_server.request_log) == 0 + assert mock_server.max_concurrent == 0 + + @pytest.mark.parametrize("latency,min_time,max_time", [(0.0, 0.0, 0.3), (0.5, 0.5, 1.0)]) + def test_latency(self, latency, min_time, max_time): + with with_mock_server(latency=latency) as server: + start = time.time() + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + elapsed = time.time() - start + assert min_time <= elapsed < max_time + + def test_max_concurrent_with_latency(self): + with with_mock_server(latency=0.1) as server: + + def send_request(): + requests.post(f"{server.url}/generate", json={"input_ids": [1], "sampling_params": {}}, timeout=5.0) + + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(send_request) for _ in range(3)] + concurrent.futures.wait(futures) + + assert server.max_concurrent == 3 + + def test_health_endpoint(self, mock_server): + response = requests.get(f"{mock_server.url}/health", timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + def test_abort_request_endpoint(self, mock_server): + response = requests.post(f"{mock_server.url}/abort_request", json={}, timeout=5.0) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +class TestGenerateEndpoint: + def test_basic(self, mock_server): + prompt = "What is 1+7?" 
+ input_ids = mock_server.tokenizer.encode(prompt, add_special_tokens=False) + assert input_ids == [3838, 374, 220, 16, 10, 22, 30] + + response = requests.post( + f"{mock_server.url}/generate", + json={ + "input_ids": input_ids, + "sampling_params": {"temperature": 0.7, "max_new_tokens": 10}, + "return_logprob": True, + }, + timeout=5.0, + ) + assert response.status_code == 200 + assert response.json() == { + "text": "\\boxed{8}", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": len(input_ids), + "cached_tokens": 0, + "completion_tokens": 5, + "output_token_logprobs": [ + [-0.0, 59], + [-0.0078125, 79075], + [-0.015625, 90], + [-0.0234375, 23], + [-0.03125, 92], + ], + }, + } + + def test_with_meta_info(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult( + text="ok", + finish_reason="stop", + cached_tokens=5, + meta_info=ProcessResultMetaInfo( + weight_version="v2.0", + routed_experts="encoded_data", + spec_accept_token_num=10, + spec_draft_token_num=15, + spec_verify_ct=3, + ), + ) + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + + assert response.json() == { + "text": "ok", + "meta_info": { + "finish_reason": {"type": "stop"}, + "prompt_tokens": 3, + "cached_tokens": 5, + "completion_tokens": 1, + "output_token_logprobs": [[-0.0, 562]], + "weight_version": "v2.0", + "routed_experts": "encoded_data", + "spec_accept_token_num": 10, + "spec_draft_token_num": 15, + "spec_verify_ct": 3, + }, + } + + def test_finish_reason_length(self): + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text="truncated output", finish_reason="length") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/generate", + json={"input_ids": [1, 2, 3], "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + data = response.json() + + finish_reason = data["meta_info"]["finish_reason"] + assert finish_reason["type"] == "length" + assert finish_reason["length"] == data["meta_info"]["completion_tokens"] + + +class TestChatCompletionsEndpoint: + def test_basic(self, mock_server): + response = requests.post( + f"{mock_server.url}/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "What is 1+5?"}], + }, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + + assert data["id"].startswith("chatcmpl-") + assert isinstance(data["created"], int) + assert data == { + "id": data["id"], + "object": "chat.completion", + "created": data["created"], + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "\\boxed{6}", "tool_calls": None}, + "logprobs": {"content": expected_logprobs(mock_server.tokenizer, "\\boxed{6}")}, + "finish_reason": "stop", + } + ], + } + + def test_with_tool_calls(self): + tool_call_response = 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n' + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=tool_call_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year is it?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + 
"index": 0, + "message": { + "role": "assistant", + "content": "Let me check for you.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}} + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, tool_call_response)}, + "finish_reason": "tool_calls", + } + + def test_with_tools_but_no_tool_call(self): + response_text = "The weather is sunny today." + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=response_text, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What's the weather?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": {"role": "assistant", "content": response_text, "tool_calls": None}, + "logprobs": {"content": expected_logprobs(server.tokenizer, response_text)}, + "finish_reason": "stop", + } + + def test_with_multiple_tool_calls(self): + multi_tool_response = ( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n' + ) + + def process_fn(_: str) -> ProcessResult: + return ProcessResult(text=multi_tool_response, finish_reason="stop") + + with with_mock_server(process_fn=process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "What year and temperature?"}], + "tools": SAMPLE_TOOLS, + }, + timeout=5.0, + ) + data = response.json() + + assert data["choices"][0] == { + "index": 0, + "message": { + "role": "assistant", + "content": "I will get year and temperature.", + "tool_calls": [ + {"id": "call00000", "type": "function", "function": {"name": "get_year", "arguments": "{}"}}, + { + "id": "call00001", + "type": "function", + "function": {"name": "get_temperature", "arguments": '{"location": "Shanghai"}'}, + }, + ], + }, + "logprobs": {"content": expected_logprobs(server.tokenizer, multi_tool_response)}, + "finish_reason": "tool_calls", + } + + +class TestMultiTurnToolCallProcessFn: + @pytest.mark.parametrize( + "prompt,expected_response", + [ + pytest.param(TwoTurnStub.FIRST_PROMPT, TwoTurnStub.FIRST_RESPONSE, id="first_turn"), + pytest.param(TwoTurnStub.SECOND_PROMPT, TwoTurnStub.SECOND_RESPONSE, id="second_turn"), + ], + ) + def test_generate_endpoint(self, prompt, expected_response): + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: + input_ids = server.tokenizer.encode(prompt, add_special_tokens=False) + response = requests.post( + f"{server.url}/generate", + json={"input_ids": input_ids, "sampling_params": {}, "return_logprob": True}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["text"] == expected_response + assert data["meta_info"]["finish_reason"] == {"type": "stop"} + + @pytest.mark.parametrize( + "messages,expected_content,expected_tool_calls,expected_finish_reason", + [ + pytest.param( + TwoTurnStub.OPENAI_MESSAGES_FIRST_TURN, + TwoTurnStub.FIRST_RESPONSE_CONTENT, + TwoTurnStub.FIRST_TOOL_CALLS_OPENAI_FORMAT, + "tool_calls", + id="first_turn", + ), + pytest.param( + TwoTurnStub.OPENAI_MESSAGES_SECOND_TURN_FROM_CLIENT, + TwoTurnStub.SECOND_RESPONSE, + None, + "stop", + id="second_turn", + ), + ], + ) + def 
test_chat_completions_endpoint(self, messages, expected_content, expected_tool_calls, expected_finish_reason): + with with_mock_server(process_fn=TwoTurnStub.process_fn) as server: + response = requests.post( + f"{server.url}/v1/chat/completions", + json={"model": "test", "messages": messages, "tools": SAMPLE_TOOLS}, + timeout=5.0, + ) + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["message"]["content"] == expected_content + assert data["choices"][0]["message"]["tool_calls"] == expected_tool_calls + assert data["choices"][0]["finish_reason"] == expected_finish_reason diff --git a/tests/fast/utils/test_utils/test_mock_tools.py b/tests/fast/utils/test_utils/test_mock_tools.py new file mode 100644 index 000000000..3f2116ec0 --- /dev/null +++ b/tests/fast/utils/test_utils/test_mock_tools.py @@ -0,0 +1,111 @@ +import asyncio + +import pytest +from pydantic import TypeAdapter +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.core_types import ToolCallItem +from sglang.srt.function_call.function_call_parser import FunctionCallParser + +from miles.utils.test_utils.mock_tools import SAMPLE_TOOLS, TwoTurnStub, execute_tool_call + + +class TestExecuteToolCall: + def test_execute_get_year(self): + result = asyncio.run(execute_tool_call("get_year", {})) + assert result == '{"year": 2026}' + + def test_execute_get_temperature(self): + result = asyncio.run(execute_tool_call("get_temperature", {"location": "Mars"})) + assert result == '{"temperature": -60}' + + +class TestApplyChatTemplateWithTools: + EXPECTED_PROMPT_WITHOUT_TOOLS = ( + "<|im_start|>user\n" "What's the weather in Paris?<|im_end|>\n" "<|im_start|>assistant\n" + ) + + EXPECTED_PROMPT_WITH_TOOLS = ( + "<|im_start|>system\n" + "# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + "\n" + '{"type": "function", "function": {"name": "get_year", "description": "Get current year", "parameters": {"type": "object", "properties": {}, "required": []}}}\n' + '{"type": "function", "function": {"name": "get_temperature", "description": "Get temperature for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}\n' + "\n\n" + "For each function call, return a json object with function name and arguments within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "<|im_end|>\n" + "<|im_start|>user\n" + "What's the weather in Paris?<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + @pytest.mark.parametrize( + "tools,expected", + [ + pytest.param(None, EXPECTED_PROMPT_WITHOUT_TOOLS, id="without_tools"), + pytest.param(SAMPLE_TOOLS, EXPECTED_PROMPT_WITH_TOOLS, id="with_tools"), + ], + ) + def test_apply_chat_template(self, tools, expected): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True) + messages = [{"role": "user", "content": "What's the weather in Paris?"}] + + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools) + + assert prompt == expected + + +class TestSGLangFunctionCallParser: + """Test to demonstrate and ensure SGLang function call parser have features we need without breaking changes.""" + + @pytest.mark.parametrize( + "model_output,expected", + [ + pytest.param( + 'Let me check for you.\n\n{"name": "get_year", "arguments": {}}\n', + ( + "Let me check for you.", + 
[ToolCallItem(tool_index=-1, name="get_year", parameters="{}")], + ), + id="single_tool_call", + ), + pytest.param( + "I will get year and temperature.\n" + '\n{"name": "get_year", "arguments": {}}\n\n' + '\n{"name": "get_temperature", "arguments": {"location": "Shanghai"}}\n', + ( + "I will get year and temperature.", + [ + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Shanghai"}'), + ], + ), + id="multi_tool_calls", + ), + pytest.param( + "The weather is sunny today.", + ("The weather is sunny today.", []), + id="no_tool_call", + ), + pytest.param( + TwoTurnStub.FIRST_RESPONSE, + ( + "Let me get the year and temperature first.", + [ + ToolCallItem(tool_index=-1, name="get_year", parameters="{}"), + ToolCallItem(tool_index=-1, name="get_temperature", parameters='{"location": "Mars"}'), + ], + ), + id="multi_turn_first_response", + ), + ], + ) + def test_parse_non_stream(self, model_output, expected): + tools = TypeAdapter(list[Tool]).validate_python(SAMPLE_TOOLS) + parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25") + assert parser.parse_non_stream(model_output) == expected diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index c5c0838c5..9b6e69c29 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -126,6 +126,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, before_ray_job_submit=_launch_background, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_mimo_7B_mtp_only_grad.py b/tests/test_mimo_7B_mtp_only_grad.py index 97c76ace5..d90a2d7a7 100644 --- a/tests/test_mimo_7B_mtp_only_grad.py +++ b/tests/test_mimo_7B_mtp_only_grad.py @@ -135,6 +135,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_moonlight_16B_A3B.py b/tests/test_moonlight_16B_A3B.py index b1255982e..c35943ec1 100644 --- a/tests/test_moonlight_16B_A3B.py +++ b/tests/test_moonlight_16B_A3B.py @@ -113,6 +113,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index 15ca8ce5f..ae3c383ae 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -115,6 +115,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index dcdbd5834..4d7f034f6 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -120,6 +120,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index dcaaf5e1f..32b60f593 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -120,6 +120,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git 
a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py index d55262cd0..8ce1988de 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async_short.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async_short.py @@ -118,6 +118,7 @@ def execute(): num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k_short.py b/tests/test_qwen2.5_0.5B_gsm8k_short.py index afbffbc56..87edf266f 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_short.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_short.py @@ -117,6 +117,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 3d19b48ce..3d4768e42 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -93,6 +93,7 @@ def execute(): train_args=train_args, num_gpus_per_node=2, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index 3d70f3e4c..fcd777288 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -95,6 +95,7 @@ def execute(): num_gpus_per_node=2 if FEW_GPU else 4, megatron_model_type=None, train_script="train_async.py", + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_0.6B_megatron_fsdp_align.py b/tests/test_qwen3_0.6B_megatron_fsdp_align.py index 1431d8c3d..b89a2f283 100644 --- a/tests/test_qwen3_0.6B_megatron_fsdp_align.py +++ b/tests/test_qwen3_0.6B_megatron_fsdp_align.py @@ -97,6 +97,7 @@ def execute(): train_args=train_args + (f"{fsdp_args}" f"--save-debug-rollout-data {debug_data_path} "), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -109,6 +110,7 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=None, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) U.execute_train( @@ -135,6 +137,7 @@ def execute(): "--debug-train-only " ), num_gpus_per_node=NUM_GPUS, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, megatron_model_type=MODEL_TYPE, ) diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/test_qwen3_0.6B_parallel_check.py index 44f5c42fa..d0ad283d1 100644 --- a/tests/test_qwen3_0.6B_parallel_check.py +++ b/tests/test_qwen3_0.6B_parallel_check.py @@ -95,6 +95,7 @@ def execute(): ), num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) # 8 GPU CPU 1 for num_gpus in [8, 4, 2]: @@ -124,6 +125,7 @@ def execute(): train_args=args, num_gpus_per_node=num_gpus, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) train_args += "--calculate-per-token-loss " diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index adff10804..b30eeed8e 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -139,6 +139,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git 
a/tests/test_qwen3_4B_ckpt.py b/tests/test_qwen3_4B_ckpt.py index 22fb2b5fc..0df4492e1 100644 --- a/tests/test_qwen3_4B_ckpt.py +++ b/tests/test_qwen3_4B_ckpt.py @@ -124,6 +124,7 @@ def execute(mode: str = ""): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_4B_fsdp_true_on_policy.py b/tests/test_qwen3_4B_fsdp_true_on_policy.py index 7c975c7cc..03ba4094e 100644 --- a/tests/test_qwen3_4B_fsdp_true_on_policy.py +++ b/tests/test_qwen3_4B_fsdp_true_on_policy.py @@ -95,6 +95,7 @@ def execute(): "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", "CUBLAS_WORKSPACE_CONFIG": ":4096:8", "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( diff --git a/tests/test_qwen3_4B_ppo.py b/tests/test_qwen3_4B_ppo.py index 962f610fa..d4c1ac273 100644 --- a/tests/test_qwen3_4B_ppo.py +++ b/tests/test_qwen3_4B_ppo.py @@ -122,6 +122,7 @@ def execute(): train_args=train_args, num_gpus_per_node=NUM_GPUS, megatron_model_type=MODEL_TYPE, + extra_env_vars={"MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1"}, ) diff --git a/tests/test_qwen3_vl_4B_fsdp.py b/tests/test_qwen3_vl_4B_fsdp.py index fbdffd237..bc4ef3293 100644 --- a/tests/test_qwen3_vl_4B_fsdp.py +++ b/tests/test_qwen3_vl_4B_fsdp.py @@ -92,6 +92,7 @@ def execute(): extra_env_vars = { "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "MILES_EXPERIMENTAL_ROLLOUT_REFACTOR": "1", } U.execute_train( diff --git a/tests/utils/sglang_stub.py b/tests/utils/sglang_stub.py new file mode 100644 index 000000000..6eece91f8 --- /dev/null +++ b/tests/utils/sglang_stub.py @@ -0,0 +1,44 @@ +import sys +import types + + +def _ensure_package(name: str) -> None: + module = sys.modules.get(name) + if module is None: + module = types.ModuleType(name) + module.__path__ = [] + sys.modules[name] = module + + +def install_sglang_stub() -> None: + _ensure_package("sglang") + _ensure_package("sglang.srt") + _ensure_package("sglang.srt.endpoints") + _ensure_package("sglang.srt.endpoints.openai") + _ensure_package("sglang.srt.entrypoints") + _ensure_package("sglang.srt.entrypoints.openai") + + # protocol_module = types.ModuleType("sglang.srt.endpoints.openai.protocol") + + class ChatCompletionMessageGenericParam: + def __init__(self, role: str, content: str | None = None, **kwargs): + self.role = role + self.content = content + for key, value in kwargs.items(): + setattr(self, key, value) + + def model_copy(self, update: dict): + data = self.__dict__.copy() + data.update(update) + return self.__class__(**data) + + class ChatCompletionMessageUserParam(ChatCompletionMessageGenericParam): + pass + + # ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam] + + # protocol_module.ChatCompletionMessageGenericParam = ChatCompletionMessageGenericParam + # protocol_module.ChatCompletionMessageUserParam = ChatCompletionMessageUserParam + # protocol_module.ChatCompletionMessageParam = ChatCompletionMessageParam + # sys.modules["sglang.srt.endpoints.openai.protocol"] = protocol_module + # sys.modules["sglang.srt.entrypoints.openai.protocol"] = protocol_module diff --git a/tests/utils/test_chat_message_utils.py b/tests/utils/test_chat_message_utils.py new file mode 100644 index 000000000..e721af984 --- /dev/null +++ b/tests/utils/test_chat_message_utils.py @@ -0,0 +1,41 @@ +import pytest + +from miles.utils.chat_message_utils import get_think_token_end, get_think_token_start, trim_think_tokens + + +def 
test_get_think_token_start_end(): + assert get_think_token_start("qwen3") == ("<think>", 151667) + assert get_think_token_end("qwen3") == ("</think>", 151668) + + +def test_trim_think_tokens_no_think(): + tokens = [1, 2, 3] + assert trim_think_tokens(tokens, "qwen3") == tokens + + +def test_trim_think_tokens_start_only(): + tokens = [1, 151667, 2, 3] + assert trim_think_tokens(tokens, "qwen3") == [1] + + +def test_trim_think_tokens_start_and_end(): + tokens = [1, 151667, 2, 151668, 3] + assert trim_think_tokens(tokens, "qwen3") == [1] + + +def test_trim_think_tokens_end_without_start(): + tokens = [1, 151668, 2] + with pytest.raises(ValueError, match="No think token start found"): + trim_think_tokens(tokens, "qwen3") + + +def test_trim_think_tokens_multiple_starts(): + tokens = [151667, 1, 151667] + with pytest.raises(ValueError, match="Multiple think token start found"): + trim_think_tokens(tokens, "qwen3") + + +def test_trim_think_tokens_multiple_ends(): + tokens = [151667, 1, 151668, 2, 151668] + with pytest.raises(ValueError, match="Multiple think token end found"): + trim_think_tokens(tokens, "qwen3")
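
Reviewer note: trim_think_tokens itself is not modified by this diff, but the tests above fully pin down its contract for the qwen3 tokenizer (think-start id 151667, think-end id 151668). For reference, a minimal sketch that satisfies those assertions is given below; it reuses get_think_token_start and get_think_token_end from miles.utils.chat_message_utils exactly as the test file imports them, while the name trim_think_tokens_sketch and its structure are illustrative assumptions rather than the shipped implementation.

from miles.utils.chat_message_utils import get_think_token_end, get_think_token_start


def trim_think_tokens_sketch(tokens: list[int], model_type: str) -> list[int]:
    # Assumed sketch, not the real miles helper: reproduces the behavior
    # asserted in tests/utils/test_chat_message_utils.py.
    _, start_id = get_think_token_start(model_type)  # 151667 for "qwen3"
    _, end_id = get_think_token_end(model_type)      # 151668 for "qwen3"

    starts = [i for i, tok in enumerate(tokens) if tok == start_id]
    ends = [i for i, tok in enumerate(tokens) if tok == end_id]

    if not starts and not ends:
        return tokens  # nothing to trim
    if not starts:
        raise ValueError("No think token start found")
    if len(starts) > 1:
        raise ValueError("Multiple think token start found")
    if len(ends) > 1:
        raise ValueError("Multiple think token end found")

    # Both the start-only and start+end test cases expect everything from the
    # think-start token onward to be dropped.
    return tokens[: starts[0]]

Under these assumptions, trim_think_tokens_sketch([1, 151667, 2, 151668, 3], "qwen3") returns [1], matching test_trim_think_tokens_start_and_end.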