Skip to content

Commit 1a49246

Browse files
committed
fix
1 parent efc99ca commit 1a49246

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

cookbook/legacy/grpo/gsm8k.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def main():
237237
name='sampler',
238238
ranks=list(range(MODEL_GPUS, NUM_GPUS)),
239239
device_type='GPU',
240-
gpus_per_worker=1,
240+
gpus_per_worker=4,
241241
),
242242
]
243243
if USE_MEGATRON:
@@ -249,7 +249,7 @@ def main():
249249
world_size=MODEL_GPUS, dp_size=MODEL_GPUS,
250250
)
251251
sampler_mesh = DeviceMesh.from_sizes(
252-
world_size=SAMPLER_GPUS, dp_size=4
252+
world_size=SAMPLER_GPUS, tp_size=4
253253
)
254254
twinkle.initialize(
255255
mode='ray',

src/twinkle/checkpoint_engine/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def sync_weights(self, merge_and_sync=True):
8181
model_metadata = self.model.prepare_checkpoint_engine([True] + [False]*(self.model.device_mesh.world_size -1))
8282
self.sampler.prepare_checkpoint_engine(False)
8383
model_kwargs, sampler_kwargs = self.backend_cls.build_topology(
84-
self.model.device_mesh.world_size, self.sampler.device_mesh.world_size, [model_metadata],
84+
self.model.device_mesh.world_size, self.sampler.device_mesh.data_world_size, [model_metadata],
8585
)
8686
# Launch both init calls concurrently — TCPStore server (model rank 0)
8787
# blocks until all clients (sampler ranks) connect, so these MUST NOT

0 commit comments

Comments
 (0)