Skip to content

Commit 68ccd6c

Browse files
committed
fix
1 parent 91e8b49 commit 68ccd6c

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

cookbook/legacy/grpo/lora_backup.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,6 @@
4444

4545
# SwanLab is optional - only used if SWANLAB_API_KEY is set
4646
USE_SWANLAB = True
47-
os.environ['SWANLAB_API_KEY'] = '3hVJrk0veNB2NCm72UdJg'
4847
if USE_SWANLAB:
4948
import swanlab
5049
if USE_SWANLAB:

src/twinkle/model/transformers/transformers.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,7 @@ def calculate_loss(self, **kwargs):
421421
loss_value, counts = result
422422
else:
423423
loss_value = result
424-
counts = torch.tensor(1, device=loss_value.device)
424+
counts = torch.tensor(0, device=loss_value.device)
425425
optimizer_config = self.optimizer_group[adapter_name]
426426
optimizer_config.num_tokens += counts.item()
427427
if self.sp_strategy is not None and 'labels' in inputs:

src/twinkle/sampler/vllm_sampler/vllm_engine.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,7 @@ def __init__(
9494

9595
# Tokenizer is lazy loaded via get_tokenizer()
9696
self._tokenizer = None
97+
breakpoint()
9798

9899
def _create_engine(self):
99100
"""Create and return the vLLM engine."""
@@ -258,6 +259,17 @@ async def sample(
258259
lora_request = await self._get_or_load_lora(adapter_path, adapter_user_id)
259260

260261
# Generate
262+
if lora_request is None:
263+
from vllm.lora.request import LoRARequest
264+
from twinkle.sampler.vllm_sampler.vllm_worker_extension import (
265+
VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH,
266+
)
267+
lora_request = LoRARequest(
268+
lora_name=VLLM_LORA_NAME,
269+
lora_int_id=VLLM_LORA_INT_ID,
270+
lora_path=VLLM_LORA_PATH,
271+
)
272+
261273
generator = self.engine.generate(
262274
prompt=prompt,
263275
sampling_params=vllm_params,

0 commit comments

Comments (0)