Skip to content

Commit f2e26dd

Browse files
committed
fix
1 parent c2cb1dd commit f2e26dd

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

cookbook/rl/dpo_full.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
logger = get_logger()
6868

6969
# ── Configuration ─────────────────────────────────────────────────────────────
70-
USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 1))
70+
USE_MEGATRON = int(os.environ.get('USE_MEGATRON', 0))
7171
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B')
7272
DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
7373

cookbook/rl/dpo_lora.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,10 @@ def main():
147147
policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, pp_size=2)
148148
ModelClass = MegatronModel
149149
else:
150-
# Transformers: fsdp=4, dp=2
150+
# Transformers: dp_size=8
151+
# FSDP2 `forward_only` and `forward` have problems with `with unwrapped_model.disable_adapter()`
151152
from twinkle.model import TransformersModel
152-
policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=4, fsdp_size=2)
153+
policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=8)
153154
ModelClass = TransformersModel
154155

155156
twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups)

0 commit comments

Comments
 (0)