Skip to content

Commit 86539cc

Browse files
authored
Merge pull request #41 from modelscope/fix_npu_grpo
Fix npu grpo
2 parents 1dfcf43 + ad616ab commit 86539cc

File tree

4 files changed

+138
-13
lines changed

4 files changed

+138
-13
lines changed

cookbook/grpo/lora_npu.py

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99
from twinkle.model import TransformersModel
1010
from twinkle.reward import MathReward
1111
from twinkle.sampler import VLLMSampler, TorchSampler
12-
from twinkle.sampler.types import SamplingParams
12+
from twinkle.sampler.types import SamplingParams, SampleResponse
1313
from twinkle.weight_loader import NativeLoader
1414
from twinkle.rl import compute_advantages
1515

1616
# Environment variable setup
1717
os.environ.setdefault('TRUST_REMOTE_CODE', '1')
1818
os.environ.setdefault('TWINKLE_SEED', '42')
1919
os.environ.setdefault('TWINKLE_FULL_DETERMINISM', '1')
20+
os.environ.setdefault('RAY_TMPDIR', os.path.expanduser('~/tmp/ray'))
2021

2122
# Training configuration
2223
use_ref_model = os.environ.get('TWINKLE_USE_REF_MODEL', '1') != '0'
@@ -129,6 +130,29 @@ def get_sampling_params(eos_token_ids) -> SamplingParams:
129130
)
130131

131132

133+
def build_trajectories_from_sample_response(sample_response: SampleResponse, batch_list, tokenizer):
    """Convert sampler output into GRPO trajectories.

    Each sampled sequence is paired (round-robin) with its source batch
    entry; the reference assistant answer, if present, is replaced by the
    decoded sampled reply.

    Args:
        sample_response: Sampler result holding a ``sequences`` list.
        batch_list: Source batch entries (dicts with ``messages`` and
            optionally ``user_data``).
        tokenizer: Tokenizer used to decode sampled token ids; when None
            the assistant reply is left empty.

    Returns:
        List of trajectory dicts with ``messages`` and ``user_data`` keys;
        empty list when there is nothing to convert.
    """
    sequences = getattr(sample_response, 'sequences', None) if sample_response else None
    if not sequences or not batch_list:
        return []

    num_sources = len(batch_list)
    trajectories = []
    for idx, seq in enumerate(sequences):
        # NOTE(review): assumes sequences are ordered so that idx % len(batch_list)
        # maps each sample back to its originating prompt — confirm with sampler.
        source = batch_list[idx % num_sources]
        messages = [dict(m) for m in source.get('messages', [])]
        # Drop the reference answer so the sampled reply takes its place.
        if messages and messages[-1].get('role') == 'assistant':
            del messages[-1]

        if tokenizer is None:
            reply_text = ''
        else:
            reply_text = tokenizer.decode(seq.tokens, skip_special_tokens=True)
        trajectories.append({
            'messages': messages + [{'role': 'assistant', 'content': reply_text}],
            'user_data': list(source.get('user_data', [])),
        })
    return trajectories
154+
155+
132156
def debug_print_rollout(step, trajectories, ground_truths, rewards=None):
133157
"""Debug helper that prints rollout intermediates (sampling, rewards, etc.).
134158
@@ -182,6 +206,19 @@ def debug_print_rollout(step, trajectories, ground_truths, rewards=None):
182206
)
183207

184208

209+
def _collect_sample_responses(results):
210+
"""Custom collect function to merge multiple SampleResponse objects."""
211+
if not results:
212+
return SampleResponse(sequences=[])
213+
if len(results) == 1:
214+
return results[0]
215+
all_sequences = []
216+
for resp in results:
217+
if resp is not None and hasattr(resp, 'sequences'):
218+
all_sequences.extend(resp.sequences)
219+
return SampleResponse(sequences=all_sequences)
220+
221+
185222
@remote_class()
186223
class ActorGroup:
187224

@@ -226,7 +263,7 @@ def __init__(self, engine_args=None, lora_config=None, adapter_name=None, **kwar
226263
self.adapter_name = adapter_name
227264
self.lora_config = lora_config
228265

229-
@remote_function(collect='flatten')
266+
@remote_function(collect=_collect_sample_responses)
def sample(self, batch, sampling_params: SamplingParams = None):
    """Sample completions for `batch` via the group's sampler.

    Runs remotely; per-worker results are merged into one SampleResponse
    by the `collect` hook. Uses this actor's LoRA adapter name so sampling
    goes through the adapter rather than the base weights.
    """
    return self.sampler.sample(batch, sampling_params=sampling_params, adapter_name=self.adapter_name)
232269

@@ -293,6 +330,11 @@ def train():
293330
)
294331

295332
eos_token_ids = get_eos_token_ids()
333+
try:
334+
from transformers import AutoTokenizer
335+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
336+
except Exception:
337+
tokenizer = None
296338

297339
engine_args = {
298340
'model': model_path,
@@ -339,13 +381,18 @@ def train():
339381
batch_list = [batch]
340382
else:
341383
batch_list = list(batch)
342-
ground_truths = batch_list.copy()
343-
344384
sampling_params = get_sampling_params(eos_token_ids)
345385

346-
trajectories = actor_group.sample(batch_list, sampling_params)
347-
if callable(trajectories):
348-
trajectories = trajectories()
386+
sample_response = actor_group.sample(batch_list, sampling_params)
387+
if callable(sample_response):
388+
sample_response = sample_response()
389+
trajectories = build_trajectories_from_sample_response(sample_response, batch_list, tokenizer)
390+
if not trajectories:
391+
print(f'[step {step}] empty sampled trajectories, skip.', flush=True)
392+
continue
393+
394+
# Expand ground truths to align with sampled trajectory count.
395+
ground_truths = [batch_list[i % len(batch_list)] for i in range(len(trajectories))]
349396

350397
ref_logits = None
351398
if use_ref_model:
@@ -357,14 +404,19 @@ def train():
357404
else:
358405
ref_logits = ref_outputs['logits'] if isinstance(ref_outputs, dict) else ref_outputs.logits
359406

360-
rewards = reward.calculate(trajectories, ground_truths)
407+
rewards = reward(trajectories, ground_truths)
361408
if callable(rewards):
362409
rewards = rewards()
363410

364-
# Updated: compute advantages from rewards and store in trajectory
365-
advantages = compute_advantages(rewards, num_generations=num_generations)
411+
effective_num_generations = num_generations if len(rewards) % num_generations == 0 else 1
412+
scale = 'group' if effective_num_generations > 1 else 'batch'
413+
advantages = compute_advantages(
414+
rewards,
415+
num_generations=effective_num_generations,
416+
scale=scale,
417+
)
366418
for trajectory, advantage in zip(trajectories, advantages.tolist()):
367-
trajectory['advantages'] = advantage
419+
trajectory['advantages'] = float(advantage)
368420

369421
# Debug: print reward statistics (enable via TWINKLE_DEBUG=1)
370422
debug_print_rollout(step, trajectories, ground_truths, rewards=rewards)
@@ -383,5 +435,6 @@ def train():
383435
break
384436

385437

438+
386439
if __name__ == '__main__':
387440
train()

src/twinkle/rl/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,35 @@
22
from .base import Advantage
33
from .grpo import GRPOAdvantage
44
from .rloo import RLOOAdvantage
5+
6+
7+
# TODO: Temporary helpers added to unblock cookbook/grpo examples.
8+
# Each call creates a new Advantage instance, not suitable for production.
9+
# Remove once the framework provides a proper advantage computation API.
10+
def compute_advantages(rewards, num_generations=1, scale='group', **kwargs):
    """Compute GRPO advantages for a batch of rewards.

    Thin backward-compatible wrapper: instantiates a fresh GRPOAdvantage
    and forwards all arguments to it.
    """
    advantage_fn = GRPOAdvantage()
    return advantage_fn(
        rewards=rewards,
        num_generations=num_generations,
        scale=scale,
        **kwargs,
    )
18+
19+
20+
def compute_advantages_rloo(rewards, num_generations=1, scale='group', **kwargs):
    """Compute RLOO advantages for a batch of rewards.

    Thin backward-compatible wrapper: instantiates a fresh RLOOAdvantage
    and forwards all arguments to it.
    """
    advantage_fn = RLOOAdvantage()
    return advantage_fn(
        rewards=rewards,
        num_generations=num_generations,
        scale=scale,
        **kwargs,
    )
28+
29+
30+
# Explicit public API of the twinkle.rl package.
__all__ = [
    'Advantage',
    'GRPOAdvantage',
    'RLOOAdvantage',
    'compute_advantages',
    'compute_advantages_rloo',
]

src/twinkle/sampler/vllm_engine.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,16 @@ def _create_engine(self):
160160

161161
logger.info(f"VLLMEngine initialized: model={self.model_id}")
162162
return engine
163-
163+
164+
def shutdown(self):
    """Shut down the underlying vLLM AsyncLLM engine, logging any failure.

    No-op when the engine attribute is absent or None (e.g. partially
    initialized instances).
    """
    if not hasattr(self, 'engine'):
        return
    engine = self.engine
    if engine is None:
        return
    try:
        engine.shutdown()
        logger.info("VLLMEngine shutdown completed.")
    except Exception as e:
        logger.warning(f"VLLMEngine shutdown error: {e}")
172+
164173
async def get_tokenizer(self):
165174
"""Get the tokenizer asynchronously."""
166175
if self._tokenizer is None:

src/twinkle/sampler/vllm_sampler.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
- Results are collected via collect='flatten' (merged into single list)
2121
"""
2222
import asyncio
23+
import atexit
2324
import logging
2425
import os
2526
import threading
@@ -137,7 +138,10 @@ def __init__(
137138
)
138139

139140
VLLMLoraWeights().patch(self)
140-
141+
142+
self._shutdown_called = False
143+
atexit.register(self.shutdown)
144+
141145
def _run_event_loop(self):
142146
"""Run the event loop in background thread."""
143147
asyncio.set_event_loop(self._async_loop)
@@ -409,3 +413,30 @@ def wake_up(self, tags: List[str] = None, reload_weights: bool = False) -> None:
409413
Required after level 2 sleep which discards weights.
410414
"""
411415
self._run_in_loop(self.engine.wake_up(tags=tags, reload_weights=reload_weights))
416+
417+
def shutdown(self):
    """Gracefully shut down the vLLM engine and the background event loop.

    Registered via atexit so it fires automatically at interpreter exit,
    before GC destroys objects in unpredictable order. Idempotent: repeated
    calls return immediately.
    """
    if self._shutdown_called:
        return
    self._shutdown_called = True

    # Step 1: stop the vLLM engine (EngineCore process and output handler).
    try:
        if hasattr(self, 'engine'):
            engine = self.engine
            if engine is not None:
                engine.shutdown()
    except Exception as e:
        logger.warning(f"VLLMSampler engine shutdown error: {e}")

    # Step 2: stop the background asyncio loop, then wait for its thread.
    try:
        if hasattr(self, '_async_loop'):
            loop = self._async_loop
            if loop.is_running():
                loop.call_soon_threadsafe(loop.stop)
        if hasattr(self, '_async_thread'):
            thread = self._async_thread
            if thread.is_alive():
                thread.join(timeout=5)
    except Exception as e:
        logger.warning(f"VLLMSampler event loop shutdown error: {e}")

0 commit comments

Comments
 (0)