prefetch weights while waiting for pending requests to complete #728
base: main
Conversation
@JenniferWang has exported this pull request. If you are a Meta employee, you can view the originating Diff in D91092833.
Codecov Report: ❌ Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #728      +/-   ##
==========================================
- Coverage   78.33%   68.73%    -9.61%
==========================================
  Files          36       42        +6
  Lines        4209     4455      +246
==========================================
- Hits         3297     3062      -235
- Misses        912     1393      +481

☔ View full report in Codecov by Sentry.
Summary:

## tl;dr
Adds ForgeMonarchExecutor and ForgeWorkerWrapper to enable weight synchronization via TorchStore for RL training loops (e.g., GRPO). Specifically, the diff serializes the TorchStore controller Actor to MonarchExecutor so the controller can be shared.

## Test Plan
[-] Weight update correctness test: `TORCHSTORE_RDMA_ENABLED=0 PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check --config tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml`
[-] Local host: `python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml`
[-] Remote host: https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_mast-cve6ce%3APRODUCTION%3A0/logs?attempt=0&taskGroups=trainer%3A0%2Cref_model_0%3A0%2Cgenerator_0%3A0%2Cclient%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D

## Next Steps
[ ] Implement the prefetch logic & shared memory
[ ] Add metrics similar to generator v0
[ ] Perf/throughput testing compared to generator v0

Differential Revision: D90775552
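For context on the controller-sharing step described in this summary, here is a minimal sketch of the idea, assuming hypothetical names (`Controller`, `TORCHSTORE_CONTROLLER`): the owning process cloudpickle-serializes the controller handle and each worker deserializes it from its environment. This is not the actual torchstore/Monarch API.

```python
# Hypothetical sketch of sharing a controller object across processes via
# cloudpickle; Controller and TORCHSTORE_CONTROLLER are illustrative names,
# not the actual torchstore / Monarch API.
import base64
import os

import cloudpickle


class Controller:
    """Stand-in for the TorchStore controller actor handle."""

    def __init__(self, address: str):
        self.address = address


def export_controller(controller: Controller) -> None:
    # Serialize and stash the controller so child processes can recover it.
    blob = base64.b64encode(cloudpickle.dumps(controller)).decode()
    os.environ["TORCHSTORE_CONTROLLER"] = blob


def import_controller() -> Controller:
    # Mirrors the "Deserializing TorchStore Controller from environment" step.
    blob = os.environ["TORCHSTORE_CONTROLLER"]
    return cloudpickle.loads(base64.b64decode(blob))


if __name__ == "__main__":
    export_controller(Controller(address="tcp://localhost:29500"))
    print(import_controller().address)
```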
Summary:

## tl;dr
Add a tracer in v1 to log perf metrics to wandb.

## V0 vs V1 Metrics Parity Comparison

| Category | v0 Metric | v1 Metric | Parity |
|----------|-----------|-----------|--------|
| **Generate - Request Count** | `generator/generate/count_requests` (SUM) | `generator/generate/count_requests` (SUM) | ✅ Same |
| **Generate - Completion Count** | `generator/generate/count_sequences_completed` (SUM) | `generator/generate/count_sequences_completed` (SUM) | ✅ Same |
| **Generate - E2E Timing** | `generator_perf/generate/*` (Tracer, GPU) | `generator_perf/generate/*` (Tracer, GPU) | ✅ Same |
| **Update - Pending Requests** | `generator_perf/update_weights/sum_pending_gen_requests` (SUM) | N/A - AsyncLLM handles internally | ⚠️ Skip (by design) |
| **Update - Wait for Generation** | `generator_perf/update_weights/avg_waiting_for_generation_duration_s` (MEAN) | `generator_perf/update_weights/pause_generation_duration_s` (MEAN) | ✅ Equivalent - renamed for clarity |
| **Update - Fetch Weights** | `generator_perf/update_weights/wait_fetch_weights` (MEAN) | `generator_perf/update_weights/worker_load_weights_duration_s` (MEAN) | ✅ Equivalent - renamed for clarity |
| **Worker - Update Timing** | `generator_perf/update_weights/generator_worker_update/*` (trace, GPU) | `generator_perf/update_weights/generator_worker_update/*` (trace, GPU) | ✅ Same |

## Test Plan
Main GRPO app: `python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml`

```
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run drawn-waterfall-686
wandb: ⭐️ View project at https://meta.wandb.io/jiyue/grpo-training
wandb: 🚀 View run at https://meta.wandb.io/jiyue/grpo-training/runs/6pltx38p
wandb: Detected [openai] in use.
....
rvability.metric_actors.GlobalLoggingActor global_logger>]
=== [global_reduce] - METRICS STEP 1 ===
...
generator/generate/count_requests: 13.0
generator/generate/count_sequences_completed: 96.0
generator_perf/generate/total_duration_avg_s: 3.6518315022786463
generator_perf/generate/total_duration_max_s: 9.2080615234375
generator_perf/update_weights/pause_generation_duration_s: 2.8634108749683946
generator_perf/update_weights/resume_generation_duration_s: 1.918897032737732e-05
generator_perf/update_weights/worker_load_weights_duration_s: 3.506648204056546
...
```

Make sure integration tests that do not initialize the tracer still work:
`pytest tests/integration_tests/test_generator_lifecycle.py -v -s`

## Next Steps
[ ] Implement the prefetch logic & shared memory
[-] Add metrics similar to generator v0
[ ] Perf/throughput testing compared to generator v0

Reviewed By: allenwang28

Differential Revision: D91038187
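A hedged sketch of the kind of duration tracer this summary describes, not the repo's actual Tracer API: `wandb.log` is the real wandb call; the class and the metric name shown are illustrative only.

```python
# Illustrative timing tracer, not the actual forge observability API.
# Records wall-clock durations under metric names and flushes them to wandb.
import time
from contextlib import contextmanager

import wandb


class SimpleTracer:
    def __init__(self):
        self.metrics: dict[str, float] = {}

    @contextmanager
    def timeit(self, name: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.metrics[name] = time.perf_counter() - start

    def flush(self, step: int) -> None:
        # e.g. generator_perf/update_weights/pause_generation_duration_s
        wandb.log(self.metrics, step=step)
        self.metrics.clear()


tracer = SimpleTracer()
with tracer.timeit("generator_perf/update_weights/pause_generation_duration_s"):
    time.sleep(0.01)  # placeholder for the paused-generation window
# tracer.flush(step=1)  # requires wandb.init() to have been called first
```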
Summary:
Feature parity with v0: allow prefetching weights while waiting for pending requests to finish.

## Test Plan
Introduced a benchmark that simulates ongoing requests alongside the actual weight-sync logic.

Reference Group (V0)
```
================================================================================
WEIGHT SYNC BENCHMARK RESULTS
================================================================================
Model: Qwen/Qwen3-8B
Model size: 15.26 GB
Iterations: 3
Prefetch enabled: False
--------------------------------------------------------------------------------
Metric                          Time (s)        Throughput (GB/s)
--------------------------------------------------------------------------------
Avg push_weights                5.102 s         2.99 GB/s
Avg update_weights              43.738 s        0.35 GB/s
Avg total (push + update)       48.840 s
================================================================================

================================================================================
WEIGHT SYNC BENCHMARK RESULTS
================================================================================
Model: Qwen/Qwen3-8B
Model size: 15.26 GB
Iterations: 3
Prefetch enabled: True
Fetcher procs: 8
--------------------------------------------------------------------------------
Metric                          Time (s)        Throughput (GB/s)
--------------------------------------------------------------------------------
Avg push_weights                5.208 s         2.93 GB/s
Avg update_weights              29.602 s        0.52 GB/s
Avg total (push + update)       34.810 s
================================================================================
```

Test Group (V1)
```
================================================================================
WEIGHT SYNC BENCHMARK RESULTS
================================================================================
Model: Qwen/Qwen3-8B
Model size: 15.26 GB
Iterations: 3
Prefetch enabled: False
--------------------------------------------------------------------------------
Metric                          Time (s)        Throughput (GB/s)
--------------------------------------------------------------------------------
Avg push_weights                5.070 s         3.01 GB/s
Avg update_weights              39.974 s        0.38 GB/s
Avg total (push + update)       45.044 s
================================================================================

================================================================================
WEIGHT SYNC BENCHMARK RESULTS
================================================================================
Model: Qwen/Qwen3-8B
Model size: 15.26 GB
Iterations: 3
Prefetch enabled: True
Fetcher procs: 8
--------------------------------------------------------------------------------
Metric                          Time (s)        Throughput (GB/s)
--------------------------------------------------------------------------------
Avg push_weights                5.055 s         3.02 GB/s
Avg update_weights              28.730 s        0.53 GB/s
Avg total (push + update)       33.784 s
================================================================================
```

## Next Steps
[-] Implement the prefetch logic & shared memory
[-] Add metrics similar to generator v0
[ ] Perf/throughput testing compared to generator v0

Differential Revision: D91092833
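For readers skimming the idea being benchmarked: a minimal asyncio sketch of the overlap, assuming placeholder coroutines (`fetch_weights`, `drain_pending_requests`, `load_weights`) rather than the actual generator/TorchStore API. The point is that the fetch runs concurrently with waiting for in-flight requests instead of after them.

```python
# Illustrative overlap of weight prefetch with pending-request draining;
# every name here is a placeholder, not the real generator/TorchStore API.
import asyncio


async def fetch_weights(version: int) -> dict:
    await asyncio.sleep(1.0)  # stand-in for pulling shards from TorchStore
    return {"version": version}


async def drain_pending_requests() -> None:
    await asyncio.sleep(0.5)  # stand-in for letting in-flight generations finish


async def load_weights(weights: dict) -> None:
    await asyncio.sleep(0.2)  # stand-in for swapping weights into the workers


async def update_weights(version: int) -> None:
    # Prefetch runs while we wait for pending requests instead of after them.
    weights, _ = await asyncio.gather(
        fetch_weights(version),
        drain_pending_requests(),
    )
    await load_weights(weights)


asyncio.run(update_weights(version=1))
```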
0c99d56 to 969b8ab
joecummings left a comment:
I can't tell if this is supposed to be a stacked diff? It contains more than just the prefetch changes.
| """TitanTrainer with weight modification capabilities for benchmarking.""" | ||
|
|
||
| @endpoint | ||
| async def modify_weights(self, scale: float = 1.1): |
super nit: do we need to parameterize this? Can't we just (1) assume it's a floating point and (2) arbitrarily add or scale by X?
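For illustration only, a hypothetical unparameterized variant along the lines of this suggestion, assuming the trainer exposes a `torch.nn.Module` at `self.model` (this is not the actual TitanTrainer internals):

```python
# Hypothetical simplification: hard-code the perturbation instead of taking
# `scale` as an argument. Assumes the trainer exposes a torch.nn.Module at
# self.model; not the actual TitanTrainer code.
import torch


class WeightModifyingTrainer:
    def __init__(self, model: torch.nn.Module):
        self.model = model

    @torch.no_grad()
    def modify_weights(self) -> None:
        for param in self.model.parameters():
            param.mul_(1.1)  # arbitrary fixed scale, just to change the weights
```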
logger.info(
    "[ForgeMonarchExecutor] Deserializing TorchStore Controller from environment..."
)
self.torchstore_controller = cloudpickle.loads(
😅
model: str
iterations: int
prefetch_enabled: bool
n_fetcher_procs: int
Can we test how this parameter affects throughput?
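One way to check this would be a small sweep over the fetcher process count. A hypothetical harness sketch; `run_benchmark` here is a stand-in that returns dummy durations, not the benchmark entry point added in this PR:

```python
# Hypothetical sweep over n_fetcher_procs to see its effect on throughput.
# run_benchmark is a placeholder that returns a dummy avg update time; the
# real runner from this PR would be swapped in here.
MODEL_SIZE_GB = 15.26


def run_benchmark(n_fetcher_procs: int) -> float:
    # Placeholder: would launch the weight-sync benchmark with the given
    # number of fetcher processes and return avg update_weights seconds.
    return 30.0 / max(n_fetcher_procs, 1) + 10.0


for n in (1, 2, 4, 8, 16):
    avg_update_s = run_benchmark(n_fetcher_procs=n)
    print(f"n_fetcher_procs={n:>2}: {avg_update_s:6.2f}s "
          f"({MODEL_SIZE_GB / avg_update_s:.2f} GB/s)")
```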
Summary:
Feature parity with v0: allow prefetching weights while waiting for the pending requests to finish.
Test Plan
Introduced a benchmark that simulates ongoing requests alongside the actual weight-sync logic.
Reference Group (V0) and Test Group (V1): see the benchmark results in the commit summary above.
Next Steps
[-] Implement the prefetch logic & shared memory
[-] Add metrics similar to generator v0
[ ] Perf/throughput testing compared to generator v0
Differential Revision: D91092833