Add rl example #45

Merged: Yunnglin merged 12 commits into dev from add_rl_example on Feb 8, 2026
Conversation

Collaborator

@Yunnglin Yunnglin commented Feb 8, 2026

No description provided.

Copilot AI review requested due to automatic review settings February 8, 2026 15:13
@gemini-code-assist
Contributor

Summary of Changes

Hello @Yunnglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive reinforcement learning examples, significantly refactoring the Twinkle sampler server to enhance its API, improve LoRA adapter management, and support diverse sampler backends. It also extends client-side sampling capabilities and refines dataset serialization for greater flexibility. Additionally, the server setup documentation has been updated to provide clearer guidance on Ray cluster configurations, ensuring a more robust and user-friendly experience for deploying and utilizing the framework's advanced features.

Highlights

  • New Reinforcement Learning Examples: Introduced two new Group Relative Policy Optimization (GRPO) training examples: one using the Tinker-compatible client API and another using the Twinkle client API. These examples demonstrate how to perform RL training, including periodic weight saving for samplers, sampling completions, and computing rewards/advantages.
  • Twinkle Sampler Server Refactor: The Twinkle sampler server (src/twinkle/server/twinkle/sampler.py) has undergone a significant refactor. It now utilizes Pydantic models for API requests/responses, inherits from AdapterManagerMixin for robust LoRA adapter lifecycle management (including inactivity timeouts), and supports both VLLM and Torch sampler backends.
  • Enhanced Client-Side Sampling: The client-side sampling methods in client_tools/client_generator.py and src/twinkle_client/sampler/vllm_sampler.py have been extended to support adapter_uri for specifying LoRA adapter paths and num_samples for generating multiple completions per prompt, offering greater flexibility in inference.
  • Improved Dataset Serialization: The serialization logic for DatasetMeta objects has been enhanced to correctly handle various data_slice types (e.g., range, list, tuple) when communicating over HTTP, enabling more flexible dataset definitions.
  • Updated Server Configuration and Documentation: The server_config.yaml now includes an explicit sampler service configuration for vLLM, and the server documentation has been significantly expanded with detailed instructions on configuring Ray clusters and understanding node ranks for different components.
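The data_slice serialization described above can be sketched as follows. This is a minimal illustration assuming a simple tagged-dict wire format; the function names and schema here are hypothetical stand-ins, not the actual twinkle implementation in serialize.py.

```python
# Hypothetical sketch of data_slice (de)serialization for HTTP transport.
# range objects are not JSON-serializable, so they are tagged and rebuilt.

def serialize_data_slice(data_slice):
    """Convert a data_slice (range, list, or tuple) into a JSON-safe dict."""
    if data_slice is None:
        return None
    if isinstance(data_slice, range):
        return {'type': 'range', 'start': data_slice.start,
                'stop': data_slice.stop, 'step': data_slice.step}
    if isinstance(data_slice, (list, tuple)):
        return {'type': type(data_slice).__name__, 'values': list(data_slice)}
    raise TypeError(f'Unsupported data_slice type: {type(data_slice)!r}')

def deserialize_data_slice(payload):
    """Reconstruct the original data_slice from its JSON-safe form."""
    if payload is None:
        return None
    if payload['type'] == 'range':
        return range(payload['start'], payload['stop'], payload['step'])
    if payload['type'] == 'tuple':
        return tuple(payload['values'])
    return payload['values']
```

The tagged dict round-trips cleanly through JSON, which is what HTTP-mode DatasetMeta exchange requires.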


Changelog
  • client_tools/client_generator.py
    • Updated server_url construction in __init__ to use /samplers/{model_id} instead of /models/{model_id}.
    • Modified sample method signature to include adapter_uri and num_samples parameters.
    • Changed sample method return type from SampleResponse to Dict[str, Any].
    • Adjusted http_post json_data in sample to pass adapter_uri and num_samples.
  • cookbook/client/tinker/transformer/grpo.py
    • Added a new GRPO (Group Relative Policy Optimization) training example using the Tinker-compatible client API.
    • Demonstrates save_weights_and_get_sampling_client for weight synchronization and client-side reward/advantage computation.
  • cookbook/client/twinkle/transformer/grpo.py
    • Added a new GRPO training example using the Twinkle client API.
    • Illustrates model.save() for checkpointing and passing adapter_uri to the sampler for weight sync.
    • Configures MultiLoraTransformersModel with GRPOLoss, optimizer, and LR scheduler.
  • cookbook/client/twinkle/transformer/sampler.py
    • Added a new example script for text generation inference using the Twinkle client and VLLMSampler.
    • Demonstrates preparing inputs, configuring sampling parameters, and using adapter_uri for LoRA inference.
  • cookbook/client/twinkle/transformer/server_config.yaml
    • Updated model service name and route prefix from Qwen2.5-7B-Instruct to Qwen2.5-0.5B-Instruct.
    • Added a new Sampler Service configuration for Qwen2.5-0.5B-Instruct using vllm with detailed engine_args and adapter_config.
  • docs/source/使用指引/服务端和客户端/服务端.md
    • Added extensive documentation on Ray cluster configuration, including starting Head and Worker nodes.
    • Provided examples for a 3-node cluster setup.
    • Explained how ranks in YAML configuration relate to Ray nodes for different services.
  • src/twinkle/infra/_ray/ray_helper.py
    • Modified create_workers to avoid setting CUDA_VISIBLE_DEVICES environment variable for CPU-only deployments.
  • src/twinkle/model/megatron/multi_lora_megatron.py
    • Added self.active_group = None to the __init__ method for compatibility.
  • src/twinkle/model/transformers/multi_lora_transformers.py
    • Added self.active_group = None to the __init__ method for compatibility.
  • src/twinkle/server/tinker/sampler.py
    • Removed the local parse_adapter_uri function.
    • Updated _do_sample to use checkpoint_manager.parse_adapter_uri for centralized adapter URI parsing.
  • src/twinkle/server/twinkle/common/serialize.py
    • Introduced _serialize_data_slice and _deserialize_data_slice helper functions.
    • Modified serialize_object and deserialize_object to correctly handle DatasetMeta.data_slice for various iterable types.
  • src/twinkle/server/twinkle/sampler.py
    • Refactored SamplerManagement class to inherit from AdapterManagerMixin for automated adapter lifecycle management.
    • Introduced Pydantic models (SampleRequest, SampleResponseModel, SetTemplateRequest, etc.) for API endpoints.
    • Updated __init__ to handle nproc_per_node, sampler_type, engine_args, and adapter_config.
    • Enhanced sample method to process adapter_uri, num_samples, and convert inputs/outputs using Pydantic models.
    • Removed manual adapter heartbeat and expiration logic, now handled by AdapterManagerMixin.
  • src/twinkle_client/processor/grpo.py
    • Added a new client wrapper for GRPOLossProcessor to interact with server-side HTTP endpoints.
  • src/twinkle_client/sampler/vllm_sampler.py
    • Updated server_url construction in __init__ to use /samplers/{model_id} instead of /models/{model_id}.
    • Modified sample method signature to include adapter_uri and num_samples parameters.
    • Adjusted http_post json_data in sample to pass adapter_uri and num_samples.
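The extended client request shape can be sketched as below. The adapter_uri and num_samples fields come from the changelog above; the helper name and exact payload layout are illustrative assumptions, not the actual twinkle_client code.

```python
from typing import Any, Dict, List, Optional

def build_sample_payload(inputs: List[Dict[str, Any]],
                         sampling_params: Dict[str, Any],
                         adapter_uri: Optional[str] = None,
                         num_samples: int = 1) -> Dict[str, Any]:
    """Assemble the JSON body POSTed to the /samplers/{model_id}/sample route.

    Illustrative sketch: field names mirror the changelog, not the exact
    client implementation.
    """
    return {
        'inputs': inputs,
        'sampling_params': sampling_params,
        'adapter_uri': adapter_uri,   # e.g. a twinkle:// checkpoint URI, or None
        'num_samples': num_samples,   # completions generated per prompt
    }
```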
Activity
  • The pull request introduces new features and examples related to reinforcement learning and client-server interactions.
  • It involves significant refactoring of the sampler server and client-side sampling logic.
  • New documentation has been added to guide users on Ray cluster setup for the server.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant enhancements for reinforcement learning workflows, including new GRPO training examples for both Tinker and Twinkle clients, a major refactoring of the sampler server for improved robustness and functionality, and updated client-side APIs. The changes are extensive and well-structured. I've identified a few critical issues in the new example code and server implementation that need to be addressed, along with some suggestions for improving the documentation and code quality. Overall, this is a great step forward for the project.

Comment on lines +165 to +182

```python
        old_logps_list = []
        completion_lengths = []

        sequences = sample_response.get('sequences', [])
        for seq in sequences:
            input_features.append(seq.get('new_input_feature', seq))
            old_logps_list.append(seq.get('logprobs', []))
            completion_lengths.append(len(seq.get('tokens', [])))

        if not input_features:
            logger.warning(f"Step {step}: No valid samples, skipping")
            step += 1
            continue

        # ========== 3. Compute rewards ==========
        total_rewards, format_rewards, accuracy_rewards = compute_rewards(
            input_features)
        metrics.accumulate(
```
Contributor

critical

There are a couple of issues in this block that will prevent the example from running correctly:

  1. Incorrect input_features for compute_rewards: The compute_rewards function expects a list of trajectories (dictionaries with a messages key), but input_features is populated with dictionaries from the sampler's response ({'tokens': ..., 'logprobs': ...}). You need to decode the generated tokens into text and construct trajectory dictionaries, similar to the tinker example.

  2. Incorrect inputs for model.forward_backward: The model.forward_backward method expects a list of InputFeature objects representing the full prompt + completion sequence. The current input_features list does not have the correct structure.

  3. Missing Tokenizer: To decode the tokens for reward calculation, a tokenizer is needed, but it's not initialized in this script.

I suggest restructuring this part of the training loop to correctly process the sampler's output. You'll need to initialize a tokenizer, use it to decode completions for reward calculation, and then construct new InputFeature objects for the training step by combining the prompt features with the generated sequences.
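The restructuring suggested here could look roughly like the sketch below: decode each sampled sequence into text, then wrap it in a trajectory dict with a messages key for reward computation. The helper name and trajectory schema are assumptions taken from this review comment; the real example would use the model's own tokenizer.

```python
from typing import Any, Dict, List

def build_trajectories(prompt_messages: List[Dict[str, str]],
                       sequences: List[Dict[str, Any]],
                       tokenizer) -> List[Dict[str, Any]]:
    """Turn raw sampler output into trajectory dicts for reward computation.

    Hypothetical sketch: assumes compute_rewards() expects dicts with a
    'messages' key, as the review comment above describes.
    """
    trajectories = []
    for seq in sequences:
        # Decode the generated token ids into assistant text.
        completion_text = tokenizer.decode(seq.get('tokens', []))
        trajectories.append({
            'messages': prompt_messages + [
                {'role': 'assistant', 'content': completion_text}
            ],
        })
    return trajectories
```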


```python
if body.adapter_uri:
    from .common.io_utils import create_checkpoint_manager
    token = get_token_from_request(request)
```
Contributor

critical

The import path for create_checkpoint_manager appears to be incorrect. Based on the file structure, io_utils.py is in the utils directory, not a common subdirectory. This relative import will likely cause a ModuleNotFoundError.

Suggested change

```diff
-    from .common.io_utils import create_checkpoint_manager
+    from twinkle.server.utils.io_utils import create_checkpoint_manager
     token = get_token_from_request(request)
```

```python
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.metric import CompletionRewardMetric
from twinkle.server.tinker.common import input_feature_to_datum
```
Contributor

medium

The import input_feature_to_datum from twinkle.server.tinker.common is not used in this file. It's good practice to remove unused imports to keep the code clean.

```python
# or None to use the base model
# ADAPTER_URI = None
# Example:
ADAPTER_URI = "twinkle://20260208_224851-fa3cdd11-default/weights/twinkle-epoch-2"
```
Contributor

medium

The example ADAPTER_URI contains EMPTY_TOKEN. While this is commented out, it could be confusing for users. It would be clearer to either use a more explicit placeholder like <YOUR-TOKEN-HERE> or provide a note explaining that EMPTY_TOKEN needs to be replaced with a valid token.

```yaml
      max_replicas: 1
      target_ongoing_requests: 16
    ray_actor_options:
      num_cpus: 0.1
```
(No newline at end of file.)
Contributor

medium

The file is missing a newline at the end. It's a common convention to end files with a single newline character.

```yaml
      num_cpus: 0.1
```


```bash
# Second GPU node: uses GPUs 4-7 (4 GPUs total)
CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=10.28.252.9:6379 --num-gpus=4
```
Contributor

medium

The IP address 10.28.252.9 is hardcoded in the example command. It would be helpful to add a note for users to replace this with the actual IP address of their Ray head node.
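A fuller version of the multi-node startup the docs describe might look like the sketch below. The placeholder <HEAD_NODE_IP> is intentional, per this comment, and the exact port and GPU split are assumptions matching the quoted example, not the shipped documentation verbatim.

```shell
# Head node (note its IP; workers connect to it on port 6379)
ray start --head --port=6379

# First GPU worker node: GPUs 0-3
CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --address=<HEAD_NODE_IP>:6379 --num-gpus=4

# Second GPU worker node: GPUs 4-7
CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=<HEAD_NODE_IP>:6379 --num-gpus=4
```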


```yaml
# Sampler service occupies Node 1 (worker node, GPUs 4-7)
- name: sampler-Qwen2.5-7B-Instruct
  route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct
```
Contributor

medium

The route_prefix here is /sampler/... (singular), but in the server_config.yaml and client implementations, it's /samplers/... (plural). This should be corrected to /samplers/... for consistency.

Suggested change

```diff
-  route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct
+  route_prefix: /samplers/Qwen/Qwen2.5-7B-Instruct
```

```python
def __del__(self):
    try:
        heartbeat_manager.unregister_processor(self.processor_id)
    except:
```
Contributor

medium

Using a bare except: clause is generally discouraged as it can catch unexpected system-exiting exceptions (like SystemExit or KeyboardInterrupt), making it harder to debug or interrupt the program. It's better to catch a more specific exception, like except Exception:.

Suggested change

```diff
-    except:
+    except Exception:
```

Copilot AI left a comment

Pull request overview

This PR adds a client/server RL (GRPO) cookbook example and extends the sampler HTTP API + client wrappers to support LoRA adapter loading via adapter_uri and multi-sample generation (num_samples). It also updates checkpoint/serialization utilities to better support HTTP-mode workflows used by the new examples.

Changes:

  • Update Twinkle sampler server and client to use /samplers/... routes and accept adapter_uri + num_samples.
  • Add GRPO training and sampler cookbook examples; add a GRPO processor client wrapper.
  • Add data_slice serialization support for DatasetMeta in HTTP mode and centralize adapter-URI parsing in the checkpoint manager.

Reviewed changes

Copilot reviewed 27 out of 27 changed files in this pull request and generated 24 comments.

Show a summary per file

  • src/twinkle_client/sampler/vllm_sampler.py: Update sampler client URL routing and request payload to include adapter_uri/num_samples.
  • client_tools/client_generator.py: Update sampler client code generation to match the new sampler route and request fields.
  • src/twinkle_client/processor/grpo.py: Add a new GRPO processor client wrapper for server-side preprocessing.
  • src/twinkle/server/twinkle/sampler.py: Major rewrite of the sampler service API using Pydantic request/response models plus adapter lifecycle handling.
  • src/twinkle/server/utils/io_utils.py: Add parse_adapter_uri() helper to the checkpoint manager to resolve LoRA adapter paths.
  • src/twinkle/server/twinkle/common/serialize.py: Support serializing/deserializing DatasetMeta.data_slice (e.g., range(...)) for HTTP mode.
  • src/twinkle/server/tinker/sampler.py: Switch to the centralized parse_adapter_uri() implementation.
  • src/twinkle/model/transformers/multi_lora_transformers.py: Add active_group = None compatibility field.
  • src/twinkle/model/megatron/multi_lora_megatron.py: Add active_group = None compatibility field.
  • src/twinkle/infra/_ray/ray_helper.py: Adjust CPU worker env-var handling for Ray worker creation.
  • cookbook/client/twinkle/transformer/server_config.yaml: Update example server config and add a sampler service definition.
  • cookbook/client/twinkle/transformer/sampler.py: Add a Twinkle HTTP sampler inference example.
  • cookbook/client/twinkle/transformer/grpo.py: Add a Twinkle HTTP GRPO training example using model.save() + adapter_uri.
  • cookbook/client/tinker/transformer/grpo.py: Add a Tinker-compatible GRPO training example.
Comments suppressed due to low confidence (3)

client_tools/client_generator.py:760

  • The generated VLLMSampler.__init__ returns response.json(). Returning a non-None value from __init__ raises TypeError on instantiation, so any generated client will fail at runtime. Drop the return response.json() and instead keep the response for validation / store needed values on self.

```python
            model_id = model_id.split('://')[1]
        self.server_url = f'{self.server_url}/samplers/{model_id}'
        response = http_post(
            url=f'{self.server_url}/create',
            json_data=kwargs
        )
        response.raise_for_status()

    def _send_adapter_heartbeat(self):
        """Internal method to send adapter heartbeat."""
        if not self.adapter_name:
            return
        response = http_post(
            url=f'{self.server_url}/heartbeat',
            json_data={'adapter_name': self.adapter_name}
        )
```

src/twinkle/server/twinkle/sampler.py:325

  • deploy_options defaults to None, but is expanded with **deploy_options when calling SamplerManagement.options(...). If deploy_options is omitted by the caller, this will raise a TypeError. Consider defaulting to an empty dict (e.g., deploy_options = deploy_options or {}) before using it.

```python
        nproc_per_node, device_group, device_mesh, sampler_type, engine_args, adapter_config, **kwargs)
```
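The guard this comment suggests is a one-liner; here is a minimal sketch with a stand-in function rather than the actual SamplerManagement code.

```python
def build_options(deploy_options=None, **extra):
    """Normalize deploy_options before **-expansion to avoid TypeError on None."""
    deploy_options = deploy_options or {}  # the fix suggested in the review
    return dict(**deploy_options, **extra)
```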

src/twinkle_client/sampler/vllm_sampler.py:45

  • __init__ returns response.json(). In Python, __init__ must return None; returning a dict raises TypeError: __init__() should return None, not 'dict' when constructing VLLMSampler. This will break the new cookbook examples that instantiate VLLMSampler. Remove the return value and store any needed fields on self instead.

```python
        self.adapter_name = None
        if '://' in model_id:
            model_id = model_id.split('://')[1]
        self.server_url = f'{self.server_url}/samplers/{model_id}'
        response = http_post(
            url=f'{self.server_url}/create',
            json_data=kwargs
        )
        response.raise_for_status()

    def _send_adapter_heartbeat(self):
        """Internal method to send adapter heartbeat."""
        if not self.adapter_name:
            return
        response = http_post(
            url=f'{self.server_url}/heartbeat',
```
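Both __init__ findings above admit the same fix: store the create-response on self instead of returning it. A hedged sketch with a stand-in class and an injectable post function (the real code uses the project's http_post helper):

```python
class SamplerClientSketch:
    """Illustrative stand-in for the generated VLLMSampler client."""

    def __init__(self, base_url, model_id, post_fn, **kwargs):
        self.adapter_name = None
        if '://' in model_id:
            model_id = model_id.split('://')[1]
        self.server_url = f'{base_url}/samplers/{model_id}'
        response = post_fn(url=f'{self.server_url}/create', json_data=kwargs)
        response.raise_for_status()
        # __init__ must return None: keep the body on self instead of returning it.
        self.create_info = response.json()
```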


Comment on lines 20 to +23

```diff
 from twinkle import DeviceGroup, DeviceMesh
-from twinkle.data_format import Trajectory, InputFeature
 from twinkle.sampler import vLLMSampler
-from twinkle.server.utils.validation import verify_request_token
+from twinkle.data_format import Trajectory, InputFeature, SamplingParams
+from twinkle.server.utils.adapter_manager import AdapterManagerMixin
+from twinkle.server.utils.validation import verify_request_token, get_token_from_request
```
Copilot AI Feb 8, 2026

Import path looks incorrect: twinkle.sampler.types does not exist in this repo (sampling dataclasses live under twinkle.data_format.sampling). As-is, the sampler service will fail to import at startup. Update these imports to the correct module (and drop unused ones if needed).

```python
sequences.append({
    'stop_reason': seq.stop_reason,
    'tokens': list(seq.tokens),
    'logprobs': list(seq.logprobs) if seq.logprobs is not None else None,
```
Copilot AI Feb 8, 2026

The HTTP response currently strips decoded and new_input_feature from SampledSequence. The new GRPO example (and existing sampler code) expects new_input_feature to be present to feed sampled continuations back into training. Either include these fields in the response (and in SampleResponseModel), or update the cookbook/examples to not rely on them.

Suggested change

```diff
     'logprobs': list(seq.logprobs) if seq.logprobs is not None else None,
+    'decoded': getattr(seq, 'decoded', None),
+    'new_input_feature': getattr(seq, 'new_input_feature', None),
```

Comment on lines +83 to +90

```diff
     num_samples: int = 1,
 ) -> Dict[str, Any]:
     """Sample from the model.

     Args:
         inputs: List of Trajectory or InputFeature to sample from.
         sampling_params: Sampling parameters dict.
-        adapter_name: Adapter name.
+        adapter_name: Adapter name for LoRA inference.
```
Copilot AI Feb 8, 2026

This client class subclasses twinkle.sampler.base.Sampler, whose abstract sample() contract returns SampleResponse. Changing the override to return Dict[str, Any] breaks the base-class type contract and can confuse users/type-checkers. Consider either (a) not inheriting from Sampler for HTTP clients, or (b) returning/constructing a SampleResponse object on the client side to preserve the API.

Comment on lines +44 to +45

```python
MODEL_ID = 'ms://Qwen/Qwen2.5-0.5B-Instruct'
NUM_GENERATIONS = 8
```
Copilot AI Feb 8, 2026

This GRPO example uses MODEL_ID = 'ms://Qwen/Qwen2.5-3B-Instruct', but the provided server_config.yaml in the same directory deploys Qwen2.5-0.5B-Instruct. Unless the server config is updated accordingly, the model/sampler routes won’t exist for 3B. Consider aligning the example with the shipped config, or add a clear note that the config must be changed to match MODEL_ID.

Comment on lines +299 to +306

```python
        self.sampler.add_adapter_to_sampler(full_adapter_name, config)

        self.register_adapter(full_adapter_name, token)
        allowed, reason = self.check_adapter_limit(token, True)
        if not allowed:
            raise RuntimeError(reason)

        return AddAdapterResponse(adapter_name=full_adapter_name)
```
Copilot AI Feb 8, 2026

Per-token adapter limit enforcement happens after the adapter is added to the sampler and registered. If the limit is exceeded, this raises but leaves the adapter loaded/registered (resource leak and limit bypass). Check check_adapter_limit(token, True) before adding/registering, or rollback (remove adapter + unregister) on failure.
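The reordering suggested above can be sketched as follows. The manager/sampler objects here are stand-ins mirroring the quoted snippet, not the actual twinkle classes, and remove_adapter is an assumed rollback hook.

```python
def add_adapter_safely(manager, sampler, token, adapter_name, config):
    """Enforce the per-token limit before loading, and roll back on failure."""
    allowed, reason = manager.check_adapter_limit(token, True)
    if not allowed:
        raise RuntimeError(reason)  # nothing loaded yet, so no leak
    sampler.add_adapter_to_sampler(adapter_name, config)
    try:
        manager.register_adapter(adapter_name, token)
    except Exception:
        # Undo the partially applied state before re-raising.
        sampler.remove_adapter(adapter_name)
        raise
    return adapter_name
```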

```python
from twinkle.sampler import vLLMSampler
from twinkle.server.utils.validation import verify_request_token
from twinkle.data_format import Trajectory, InputFeature, SamplingParams
from twinkle.server.utils.adapter_manager import AdapterManagerMixin
```
Copilot AI Feb 8, 2026

Import of 'SampleResponse' is not used.
Import of 'SampledSequence' is not used.

Suggested change

```python
from twinkle.server.utils.adapter_manager import AdapterManagerMixin
```

```python
from twinkle.data_format import InputFeature
from .base import InputProcessor

class GRPOLossProcessor(InputProcessor):
```
Copilot AI Feb 8, 2026

This class does not call InputProcessor.__init__ during initialization. (GRPOLossProcessor.__init__ may be missing a call to the base class __init__.)

```python
from twinkle.data_format import InputFeature
from .base import InputProcessor

class GRPOLossProcessor(InputProcessor):
```
Copilot AI Feb 8, 2026

This class does not call InputProcessor.__del__ during finalization. (GRPOLossProcessor.__del__ may be missing a call to the base class __del__.)

```python
def __del__(self):
    try:
        heartbeat_manager.unregister_processor(self.processor_id)
    except:
```
Copilot AI Feb 8, 2026

Except block directly handles BaseException.

Suggested change

```diff
-    except:
+    except Exception:
```

```python
def __del__(self):
    try:
        heartbeat_manager.unregister_processor(self.processor_id)
    except:
```
Copilot AI Feb 8, 2026

'except' clause does nothing but pass and there is no explanatory comment.

@Yunnglin Yunnglin merged commit 65ae1ed into dev Feb 8, 2026
0 of 4 checks passed