diff --git a/src/forge/actors/vllm/v1/__init__.py b/src/forge/actors/vllm/v1/__init__.py
index 4fcddd264..ed3c8fe4c 100644
--- a/src/forge/actors/vllm/v1/__init__.py
+++ b/src/forge/actors/vllm/v1/__init__.py
@@ -25,6 +25,14 @@ def __getattr__(name):
         from forge.actors.vllm.v1.monarch_executor import WorkerWrapper
 
         return WorkerWrapper
+    if name == "ForgeMonarchExecutor":
+        from forge.actors.vllm.v1.forge_executor import ForgeMonarchExecutor
+
+        return ForgeMonarchExecutor
+    if name == "ForgeWorkerWrapper":
+        from forge.actors.vllm.v1.forge_executor import ForgeWorkerWrapper
+
+        return ForgeWorkerWrapper
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
 
@@ -32,4 +40,6 @@ def __getattr__(name):
     "Generator",
     "MonarchExecutor",
     "WorkerWrapper",
+    "ForgeMonarchExecutor",
+    "ForgeWorkerWrapper",
 ]
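For context, the `__getattr__` hook above is a PEP 562 lazy re-export: importing the package stays cheap, and the heavy vLLM/Monarch modules are only pulled in when one of these symbols is first accessed. A minimal caller-side sketch (the assert is illustrative, not part of this diff):

```python
# Hypothetical caller-side view of the lazy re-export added above.
from forge.actors.vllm.v1 import ForgeMonarchExecutor  # first access triggers the import

# The class really lives in forge_executor; the package merely forwards it.
assert ForgeMonarchExecutor.__module__ == "forge.actors.vllm.v1.forge_executor"
```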
diff --git a/src/forge/actors/vllm/v1/forge_executor.py b/src/forge/actors/vllm/v1/forge_executor.py
new file mode 100644
index 000000000..75adff46d
--- /dev/null
+++ b/src/forge/actors/vllm/v1/forge_executor.py
@@ -0,0 +1,168 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Forge-specific MonarchExecutor with TorchStore weight sync.
+
+This module extends the upstream-compatible MonarchExecutor with TorchStore
+integration for weight synchronization in RL training loops. It provides:
+
+- ForgeWorkerWrapper: extends WorkerWrapper with TorchStore weight loading
+- ForgeMonarchExecutor: extends MonarchExecutor with TorchStore Controller handling
+
+Use this executor when you need weight updates from TorchStore (e.g., GRPO training).
+For inference-only workloads, use the base MonarchExecutor directly.
+
+TODO: Add shared memory weight prefetch support (prefetch_weights_to_shm,
+      n_fetcher_procs) as in the v0 Generator for faster weight loading.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import logging
+import os
+from typing import Optional
+
+import cloudpickle
+from forge.actors._torchstore_utils import extract_param_name, get_param_prefix
+from forge.actors.vllm.v1.monarch_executor import MonarchExecutor, WorkerWrapper
+from monarch.actor import endpoint
+from torchstore.client import LocalClient
+
+logger = logging.getLogger(__name__)
+
+
+class ForgeWorkerWrapper(WorkerWrapper):
+    """Worker wrapper with TorchStore weight sync capabilities."""
+
+    def __init__(self, vllm_config):
+        super().__init__(vllm_config)
+        self._torchstore_controller = None
+        self._torchstore_client: Optional[LocalClient] = None
+
+    @endpoint
+    def set_torchstore_controller(self, controller) -> None:
+        """Store the TorchStore Controller reference for weight updates.
+
+        Workers run in a subprocess with a different _controller_controller,
+        so they cannot find the Controller via get_or_spawn_controller.
+        The Controller reference is passed explicitly from ForgeMonarchExecutor.
+        """
+        self._torchstore_controller = controller
+        self._torchstore_client = None  # Reset cached client
+
+    @endpoint
+    def update_weights(self, version: int) -> int:
+        """Load weights directly from torchstore.
+
+        Args:
+            version: Policy version to load from torchstore
+
+        Returns:
+            Number of parameters loaded
+        """
+        return asyncio.run(self._load_from_torchstore(version))
+
+    async def _get_torchstore_client(self) -> LocalClient:
+        """Get or create a LocalClient using the passed Controller reference.
+
+        Workers cannot use ts.client() directly because they run in a subprocess
+        with a different _controller_controller. Instead, we create a LocalClient
+        from the Controller reference passed in by ForgeMonarchExecutor.
+        """
+        if self._torchstore_client is not None:
+            return self._torchstore_client
+
+        if self._torchstore_controller is None:
+            raise RuntimeError(
+                "TorchStore Controller not set. "
+                "ForgeMonarchExecutor must call set_torchstore_controller before weight updates."
+            )
+
+        strategy = await self._torchstore_controller.get_controller_strategy.call_one()
+        self._torchstore_client = LocalClient(
+            controller=self._torchstore_controller,
+            strategy=strategy,
+        )
+        return self._torchstore_client
+
+    async def _load_from_torchstore(self, version: int) -> int:
+        """Async helper to load from torchstore using the passed Controller."""
+        client = await self._get_torchstore_client()
+        prefix = get_param_prefix(version)
+        matching_keys = await client.keys(prefix)
+        model = self.worker.model_runner.model
+        loaded_count = 0
+        for key in matching_keys:
+            name = extract_param_name(key)
+            param = await client.get(key)
+            model.load_weights([(name, param.cuda())])
+            del param
+            loaded_count += 1
+        return loaded_count
+
+    @endpoint
+    def save_model_params(self):
+        """Save model parameters before a weight update; used for testing purposes only."""
+        logger.info("[WorkerWrapper] save model parameters for testing.")
+        if not hasattr(self, "_test_prev_params"):
+            self._test_prev_params = {}
+        for name, param in self.worker.model_runner.model.named_parameters():
+            self._test_prev_params[name] = param.detach().cpu()
+        logger.info(
+            "[WorkerWrapper] finished saving model parameters, len = %d",
+            len(self._test_prev_params),
+        )
+
+    @endpoint
+    def validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[WorkerWrapper] start validating model parameters.")
+        if not hasattr(self, "_test_prev_params"):
+            self._test_prev_params = {}
+        return validate_fn(
+            self._test_prev_params, self.worker.model_runner.model, logger
+        )
+
+
+class ForgeMonarchExecutor(MonarchExecutor):
+    """MonarchExecutor with TorchStore integration for weight sync.
+
+    Extends the base MonarchExecutor to:
+    - Deserialize the TorchStore Controller from the environment
+    - Pass the Controller to workers for direct weight loading
+    - Use ForgeWorkerWrapper instead of the base WorkerWrapper
+    """
+
+    worker_class = ForgeWorkerWrapper
+
+    def _init_executor(self) -> None:
+        """Initialize the executor and deserialize the TorchStore Controller."""
+        super()._init_executor()
+
+        controller_str = os.environ.get("VLLM_TORCHSTORE_CONTROLLER")
+        if controller_str:
+            logger.info(
+                "[ForgeMonarchExecutor] Deserializing TorchStore Controller from environment..."
+            )
+            self.torchstore_controller = cloudpickle.loads(
+                base64.b64decode(controller_str)
+            )
+            logger.info(
+                f"[ForgeMonarchExecutor] TorchStore Controller deserialized: {self.torchstore_controller}"
+            )
+            self.workers.set_torchstore_controller.call(
+                self.torchstore_controller
+            ).get()
+        else:
+            self.torchstore_controller = None
+            logger.warning(
+                "[ForgeMonarchExecutor] No TorchStore Controller found in environment. "
+                "Weight updates via torchstore will not work."
+            )
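The `validate_model_params` endpoint above accepts an arbitrary `validate_fn(prev_params, model, logger)` callable. A minimal sketch of such a callable, assuming the snapshot taken by `save_model_params` and the live model share parameter names (`weights_changed` is a hypothetical helper, not part of this diff):

```python
import torch

def weights_changed(prev_params, model, logger):
    """Return True iff every previously snapshotted parameter differs after the update."""
    stale = [
        name
        for name, param in model.named_parameters()
        if name in prev_params and torch.equal(prev_params[name], param.detach().cpu())
    ]
    if stale:
        logger.warning("parameters unchanged after weight update: %s", stale[:5])
    return not stale
```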
diff --git a/src/forge/actors/vllm/v1/generator.py b/src/forge/actors/vllm/v1/generator.py
index acfa60a2e..22fea884e 100644
--- a/src/forge/actors/vllm/v1/generator.py
+++ b/src/forge/actors/vllm/v1/generator.py
@@ -21,6 +21,7 @@
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
 from monarch.actor import endpoint, this_host
+from torchstore.api import _controller as get_torchstore_controller
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.llm import UsageContext
 from vllm.outputs import RequestOutput
@@ -197,14 +198,21 @@ async def setup(self, host_mesh, worker_registry, gpu_ids: list[str]):
         ).decode("utf-8")
         os.environ["VLLM_MONARCH_WORKER_REGISTRY"] = serialized_registry
 
+        # Serialize TorchStore Controller reference for workers to access torchstore
+        torchstore_controller = await get_torchstore_controller()
+        serialized_controller = base64.b64encode(
+            cloudpickle.dumps(torchstore_controller)
+        ).decode("utf-8")
+        os.environ["VLLM_TORCHSTORE_CONTROLLER"] = serialized_controller
+
         # Force 'spawn' multiprocessing method for Monarch actors.
         # This follows vLLM's Ray integration pattern.
         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
-        # Set the executor backend to MonarchExecutor via string path
+        # Set the executor backend to ForgeMonarchExecutor via string path
         # This avoids import deadlock when vLLM spawns EngineCore subprocess
         self.vllm_config.parallel_config.distributed_executor_backend = (
-            "forge.actors.vllm.v1.monarch_executor.MonarchExecutor"
+            "forge.actors.vllm.v1.forge_executor.ForgeMonarchExecutor"
         )
 
         from vllm.v1.executor.abstract import Executor
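The setup hunk above hands the TorchStore Controller to the spawned EngineCore process through an environment variable, since env vars can only carry strings. A minimal sketch of the round-trip, assuming the object is something cloudpickle can serialize (the helper names are illustrative, not part of this diff):

```python
import base64
import os

import cloudpickle

def stash_in_env(obj, var: str = "VLLM_TORCHSTORE_CONTROLLER") -> None:
    # cloudpickle -> bytes -> base64 -> str, the only type os.environ accepts
    os.environ[var] = base64.b64encode(cloudpickle.dumps(obj)).decode("utf-8")

def load_from_env(var: str = "VLLM_TORCHSTORE_CONTROLLER"):
    # Mirrors ForgeMonarchExecutor._init_executor: returns None if the parent never set it
    raw = os.environ.get(var)
    return cloudpickle.loads(base64.b64decode(raw)) if raw else None
```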
@@ -309,6 +317,61 @@ async def shutdown(cls, actor):
 
         logger.info("shutdown() complete")
 
+    @endpoint
+    async def update_weights(
+        self,
+        version: int,
+    ) -> None:
+        """Update weights on the generator from torchstore.
+
+        This method:
+        1. Pauses generation and waits for in-flight requests to complete
+        2. Updates weights on workers from torchstore
+        3. Resumes generation
+
+        Note: This is NOT the standard vLLM weight update approach. vLLM typically
+        uses `collective_rpc` on EngineClient, which internally routes calls to
+        workers via the executor. However, `collective_rpc` uses msgspec/msgpack
+        serialization, which does not support arbitrary Python objects by default
+        (only with VLLM_ALLOW_INSECURE_SERIALIZATION=1). This makes it difficult to
+        pass complex objects like torchstore storage handles. Instead, we use a
+        monarch-native approach where the Generator actor directly calls the worker
+        mesh (`self.workers.update_weights`) via Monarch RPC, which uses cloudpickle
+        and natively supports Monarch actor references for torchstore integration.
+
+        Args:
+            version: Policy version to load from torchstore
+        """
+        if self.llm is None:
+            raise RuntimeError("Generator not initialized. Call setup() first.")
+
+        logger.info(f"Starting weight update to v{version}")
+
+        await self.llm.pause_generation(
+            wait_for_inflight_requests=True, clear_cache=True
+        )
+
+        try:
+            await self.workers.update_weights.call(version)
+            self.generator_version = version
+            logger.info(f"Updated weights from torchstore v{version}")
+        finally:
+            await self.llm.resume_generation()
+
+        logger.info(f"Weight update complete, now v{version}")
+
+    @endpoint
+    async def save_model_params(self):
+        """Save model parameters before a weight update; used for testing purposes only."""
+        logger.info("save model parameters for testing.")
+        await self.workers.save_model_params.call()
+
+    @endpoint
+    async def validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("start validating model parameters.")
+        return await self.workers.validate_model_params.call(validate_fn)
+
     def _to_completions(
         self, request_output: RequestOutput, prompt: str
     ) -> list[Completion]:
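Taken together, the new endpoints give a trainer-side flow roughly like the following. This is a hypothetical sketch: the `generator` handle, the weight-push step, and `weights_changed` (from the note above) are assumptions, not part of this diff; only `update_weights`, `save_model_params`, and `validate_model_params` come from the code here.

```python
# Hypothetical driver loop around the new Generator endpoints.
async def sync_and_check(generator, version: int) -> None:
    await generator.save_model_params.call()  # snapshot current weights (testing only)

    # ... trainer writes updated parameters into torchstore under `version` ...

    # Pauses generation, streams weights from torchstore, then resumes generation.
    await generator.update_weights.call(version=version)

    # Confirm every parameter actually moved.
    assert await generator.validate_model_params.call(weights_changed)
```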