Skip to content

Commit b1b0e71

Browse files
committed
fix VLLM->vLLM
1 parent 6d20015 commit b1b0e71

File tree

5 files changed

+9
-9
lines changed

5 files changed

+9
-9
lines changed

cookbook/client/twinkle/transformer/grpo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from twinkle_client.dataloader import DataLoader
3737
from twinkle_client.dataset import Dataset
3838
from twinkle_client.model import MultiLoraTransformersModel
39-
from twinkle_client.sampler import VLLMSampler
39+
from twinkle_client.sampler import vLLMSampler
4040

4141
logger = get_logger()
4242

@@ -116,7 +116,7 @@ def train():
116116
model.set_template('Template', model_id=MODEL_ID)
117117

118118
# Step 4: Configure the sampler
119-
sampler = VLLMSampler(model_id=MODEL_ID)
119+
sampler = vLLMSampler(model_id=MODEL_ID)
120120
sampler.set_template('Template', model_id=MODEL_ID)
121121

122122
# Step 5: Setup metrics and advantage function

cookbook/client/twinkle/transformer/sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from twinkle import get_logger
1919
from twinkle_client import init_twinkle_client
20-
from twinkle_client.sampler import VLLMSampler
20+
from twinkle_client.sampler import vLLMSampler
2121

2222
logger = get_logger()
2323

@@ -39,7 +39,7 @@ def sample():
3939
)
4040

4141
# Step 3: Create the sampler client pointing to the model on the server
42-
sampler = VLLMSampler(model_id=MODEL_ID)
42+
sampler = vLLMSampler(model_id=MODEL_ID)
4343

4444
# Step 4: Set the chat template so the sampler can encode Trajectory inputs
4545
sampler.set_template('Template', model_id=MODEL_ID)

src/twinkle/metric/loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,5 @@ def calculate(self):
6565
if avg_loss is not None:
6666
results['loss'] = f'{avg_loss:.4f}'
6767
if grad_norm > 0:
68-
results['grad_norm'] = f'{grad_norm:.2f}'
69-
return results
68+
results['grad_norm'] = f'{grad_norm:.6f}'
69+
return results

src/twinkle/server/twinkle/sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any],
163163

164164
# Initialize sampler based on type
165165
if sampler_type == 'vllm':
166-
from twinkle.sampler import VLLMSampler
166+
from twinkle.sampler import vLLMSampler
167167
sampler_kwargs = engine_args or {}
168-
self.sampler = VLLMSampler(
168+
self.sampler = vLLMSampler(
169169
model_id=model_id,
170170
engine_args=sampler_kwargs,
171171
device_mesh=self.device_mesh,

tests/sampler/test_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def test_set_template(self):
239239
# =============================================================================
240240

241241
@pytest.mark.skip(reason="Requires model and GPU")
242-
class TestVLLMSamplerIntegration:
242+
class TestvLLMSamplerIntegration:
243243
"""Integration tests for vLLMSampler."""
244244

245245
def test_sample_with_trajectory(self):

0 commit comments

Comments
 (0)