Merged
63 changes: 63 additions & 0 deletions .github/copilot-instructions.md
@@ -0,0 +1,63 @@
# Twinkle AI Coding Agent Guidelines

These instructions help AI agents work productively in this repo. Focus on concrete repo patterns and workflows.

## Big Picture
- **Goal:** Training and serving LLMs with multi-adapter LoRA, efficient data handling, and distributed execution across Ray and Torch.
- **Core Modules:**
- Infrastructure & distributed orchestration: [src/twinkle/infra/__init__.py](src/twinkle/infra/__init__.py)
- Device layout & platform abstraction: [src/twinkle/utils/platform.py](src/twinkle/utils/platform.py), [src/twinkle/utils/framework.py](src/twinkle/utils/framework.py)
- Model stack (Transformers + Multi-LoRA): [src/twinkle/model/multi_lora_transformers.py](src/twinkle/model/multi_lora_transformers.py)
- Sampler (vLLM integration): [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py)
- Losses & metrics: [src/twinkle/loss](src/twinkle/loss), [src/twinkle/metric](src/twinkle/metric)
- Templates & preprocessing: [src/twinkle/template](src/twinkle/template), [src/twinkle/preprocessor](src/twinkle/preprocessor)
- Model/Processor HTTP services via Ray Serve: [src/twinkle/server/twinkle](src/twinkle/server/twinkle)
- Hub integrations (ModelScope/HF): [src/twinkle/hub/hub.py](src/twinkle/hub/hub.py)

## Architecture & Patterns
- **Lazy import surface:** [src/twinkle/__init__.py](src/twinkle/__init__.py) exposes a small, lazy API (`_LazyModule`); import public symbols from here when possible.
- **Distributed mode selection:** `twinkle.infra.initialize()` toggles between local and Ray modes. Ray mode requires `TWINKLE_MODE=ray` or `initialize(mode='ray', ...)`.
- **Remote execution decorators:**
- `remote_class()` wraps classes for Ray placement; auto-injects `DeviceMesh` if missing.
- `remote_function(dispatch='slice', execute='all', collect='none')` patches methods for distributed dispatch/collect.
- See usage in [src/twinkle/model/multi_lora_transformers.py](src/twinkle/model/multi_lora_transformers.py) and [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py).
- **Device topology:** Represented by `DeviceMesh`/`DeviceGroup`. Visualize with `twinkle.infra.get_device_placement()`; examples in [tests/infra/test_infra_graph.py](tests/infra/test_infra_graph.py).
- **Platform abstractions:** `GPU`/`NPU` selection via env and device discovery. Rank/world size read from env (`RANK`, `WORLD_SIZE`, etc.). See [src/twinkle/utils/platform.py](src/twinkle/utils/platform.py).
- **Hub usage:** `HubOperation` routes to HF or ModelScope by `hf://` or `ms://` prefixes. Dataset/model download/push helpers in [src/twinkle/hub/hub.py](src/twinkle/hub/hub.py).
- **Plugin loading:** Use `Plugin.load_plugin(id, Base)` for remote code from hubs; guarded by `trust_remote_code()` to prevent unsafe execution. See [src/twinkle/utils/plugin.py](src/twinkle/utils/plugin.py).
- **Multi-LoRA conventions:**
- `MultiLoraTransformersModel` wraps a base Transformers model via `MultiAdapter` to manage multiple LoRA adapters.
- FSDP is unsupported for Multi-LoRA (`fsdp_world_size == 1` enforced). Adapter params are strictly controlled to avoid training base weights.
- Adapter ops are routed through remote functions and grouped by DP process groups.
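The dispatch semantics above can be pictured with a plain-Python toy. This is an illustrative sketch only, not twinkle's implementation: the real `remote_function` places shards on Ray actors according to the `DeviceMesh`, whereas `remote_function_toy`, `double`, and the sequential shard loop below are all invented for the example.

```python
from typing import Callable

def remote_function_toy(dispatch: str = "slice") -> Callable:
    """Toy stand-in for a remote_function-style decorator: shard the
    batch argument, run each shard, and flatten the results."""
    def decorator(fn: Callable) -> Callable:
        def wrapper(batch: list, world_size: int = 2) -> list:
            if dispatch == "slice":
                # Each "worker" gets a strided shard of the batch.
                shards = [batch[i::world_size] for i in range(world_size)]
            else:
                # dispatch='all': every worker sees the full batch.
                shards = [batch] * world_size
            # In Ray mode each shard would run on a remote actor;
            # here we run them sequentially and flatten the outputs.
            return [y for shard in shards for y in fn(shard)]
        return wrapper
    return decorator

@remote_function_toy(dispatch="slice")
def double(xs: list) -> list:
    return [x * 2 for x in xs]

print(double([1, 2, 3, 4]))  # [2, 6, 4, 8] -- order follows shard order
```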

## Developer Workflows
- **Install:** Python 3.11+. Install with Poetry or pip.
- Poetry: `poetry install --with transformers,ray`
- Pip (editable): `pip install -e .[transformers,ray]`
- **Run tests:**
- Unit tests: `python -m unittest tests/infra/test_infra_graph.py`
- **Local single-process dev:**
- Initialize infra: `twinkle.initialize(mode='local', seed=42)`
- Inspect device placement: call `twinkle.infra.get_device_placement()`.
- **Ray Serve demo (HTTP services):**
- Config and launcher: [cookbook/client/server.py](cookbook/client/server.py), [cookbook/client/server_config.yaml](cookbook/client/server_config.yaml)
- Start:
- `python cookbook/client/server.py`
- Endpoints print on startup (default `localhost:8000`).
- Model app binds `MultiLoraTransformersModel` and exposes routes like `/add_adapter_to_model`, `/forward`, `/calculate_loss`, etc. See [src/twinkle/server/twinkle/model.py](src/twinkle/server/twinkle/model.py).
- **vLLM inference:** Use `VLLMSampler` with engine args; LoRA weight sync via `patch.vllm_lora_weights`. See [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py).
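As a sketch of the endpoint listing the Ray Serve launcher prints on startup: the route prefixes below are the ones used by the tinker server config in this PR, and the client demo's YAML may define different prefixes. The real launcher reads them from the config via OmegaConf.

```python
# Toy reconstruction of the startup endpoint listing; the app names and
# route prefixes are assumptions taken from this PR's tinker config.
apps = [
    {"name": "server", "route_prefix": "/api/v1"},
    {"name": "models", "route_prefix": "/api/v1/model"},
]
endpoints = [f"http://localhost:8000{app['route_prefix']}" for app in apps]
for endpoint in endpoints:
    print(f" - {endpoint}")
```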

## Conventions & Gotchas
- **Safety:** Remote plugin code requires `trust_remote_code()` to be true; avoid loading arbitrary strings into adapter configs (enforced in Multi-LoRA).
- **Env-driven ranks:** Many utilities read ranks/world size from env; set `WORLD_SIZE`, `RANK`, `LOCAL_RANK` when using torchrun.
- **Determinism:** `seed_everything(seed, full_determinism)` controls CUDA/NPU determinism; may set envs like `CUDA_LAUNCH_BLOCKING`.
- **Adapter lifecycle:** Server auto-removes inactive adapters (heartbeat required); per-token adapter limits are enforced. See cleanup in [src/twinkle/server/twinkle/model.py](src/twinkle/server/twinkle/model.py).
- **Templates:** Tokenization/encode via `Template` (e.g., `Qwen3Template`), producing `InputFeature` for model forward. See [src/twinkle/template/base.py](src/twinkle/template/base.py).
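A minimal sketch of the env-driven rank convention; the variable names follow the torchrun convention cited above, and `rank_info` is a hypothetical helper, not a twinkle API (the exact variables twinkle reads live in [src/twinkle/utils/platform.py](src/twinkle/utils/platform.py)).

```python
import os

def rank_info() -> tuple:
    # Torchrun-convention env vars; defaults match single-process local mode.
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank, local_rank, world_size

# Simulate what torchrun would export for rank 1 of a 4-process job.
os.environ.update({"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "4"})
print(rank_info())  # (1, 1, 4)
```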

## Examples
- **Visualize a custom mesh:** create `DeviceMesh` and call `get_device_placement()`; example in [tests/infra/test_infra_graph.py](tests/infra/test_infra_graph.py).
- **Add LoRA adapter via HTTP:** POST to `/add_adapter_to_model` with serialized `LoraConfig`; see server routes in [src/twinkle/server/twinkle/model.py](src/twinkle/server/twinkle/model.py).
- **Sample with vLLM:** Configure `VLLMSampler`, set `Template`/`Processor`, then `sample()` on `Trajectory` list; see [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py).
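A hedged sketch of constructing such a request: the payload field names below are illustrative only, not the actual `LoraConfig` schema, and the route assumes the default `localhost:8000` host; check the server routes in [src/twinkle/server/twinkle/model.py](src/twinkle/server/twinkle/model.py) for the real contract.

```python
import json

# Hypothetical request body -- field names are illustrative, not the
# real LoraConfig schema served by /add_adapter_to_model.
payload = {
    "adapter_name": "demo-adapter",
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
}
body = json.dumps(payload)
# A client would then POST it, e.g. with requests:
# requests.post("http://localhost:8000/add_adapter_to_model", data=body,
#               headers={"Content-Type": "application/json"})
print(body)
```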

---
Questions or gaps? Tell us where guidance is unclear (e.g., missing run scripts, Ray cluster setup), and we’ll refine this document.
1 change: 1 addition & 0 deletions .gitignore
@@ -157,3 +157,4 @@ megatron_output/

# ast template
ast_index_file.py
test_cookbook/
2 changes: 1 addition & 1 deletion cookbook/sft/multi_lora.py
@@ -30,7 +30,7 @@
)


-twinkle.initialize(mode='ray', nproc_per_node=4, groups=device_group, global_device_mesh=device_mesh, lazy_collect=False)
+twinkle.initialize(mode='local', nproc_per_node=4, groups=device_group, global_device_mesh=device_mesh, lazy_collect=False)


def train():
12 changes: 9 additions & 3 deletions cookbook/sft/streaming_dataset.py
@@ -19,10 +19,16 @@
]


+# device_mesh = DeviceMesh(
+#     device_type='cuda',
+#     mesh=np.array([[0,1], [2,3]]),
+#     mesh_dim_names=('dp', 'fsdp')
+# )
+
 device_mesh = DeviceMesh(
-    device_type='cuda',
-    mesh=np.array([[0,1], [2,3]]),
-    mesh_dim_names=('dp', 'fsdp')
+    device_type='cuda',
+    mesh=np.array([0,1,2,3]),
+    mesh_dim_names=('dp',)
 )

twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh)
99 changes: 99 additions & 0 deletions cookbook/tinker/lora.py
@@ -0,0 +1,99 @@
#%%

from twinkle_client import init_tinker_compat_client
service_client = init_tinker_compat_client(base_url='http://localhost:8000')

print("Available models:")
for item in service_client.get_server_capabilities().supported_models:
    print("- " + item.model_name)

#%%
base_model = "ms://Qwen/Qwen2.5-0.5B-Instruct"
training_client = service_client.create_lora_training_client(
    base_model=base_model
)

#%%
# Create some training examples
examples = [
    {"input": "banana split", "output": "anana-bay plit-say"},
    {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
    {"input": "donut shop", "output": "onut-day op-shay"},
    {"input": "pickle jar", "output": "ickle-pay ar-jay"},
    {"input": "space exploration", "output": "ace-spay exploration-way"},
    {"input": "rubber duck", "output": "ubber-ray uck-day"},
    {"input": "coding wizard", "output": "oding-cay izard-way"},
]

# Convert examples into the format expected by the training client
from tinker import types
from modelscope import AutoTokenizer
# Get the tokenizer from the training client
# tokenizer = training_client.get_tokenizer()  # NOTE: makes a network call to Hugging Face
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)

def process_example(example: dict, tokenizer) -> types.Datum:
    # Format the input with an Input/Output template.
    # For most real use cases you'll want to use a renderer / chat template
    # (see later docs), but here we'll keep it simple.
    prompt = f"English: {example['input']}\nPig Latin:"

    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    # Add a space before the output string, and finish with a double newline.
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)

    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights

    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]  # We're predicting the next token, so targets need to be shifted.
    weights = weights[1:]

    # A datum is a single training example for the loss function.
    # It has model_input, the input sequence that'll be passed into the LLM,
    # and loss_fn_inputs, a dictionary of extra inputs used by the loss function.
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
    )

processed_examples = [process_example(ex, tokenizer) for ex in examples]

# Visualize the first example for debugging purposes
datum0 = processed_examples[0]
print(f"{'Input':<20} {'Target':<20} {'Weight':<10}")
print("-" * 50)
for inp, tgt, wgt in zip(datum0.model_input.to_ints(), datum0.loss_fn_inputs['target_tokens'].tolist(), datum0.loss_fn_inputs['weights'].tolist()):
    print(f"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}")

#%%
import numpy as np
for _ in range(6):
    fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
    optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))

    # Wait for the results
    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()

    # fwdbwd_result contains the logprobs of all the tokens we put in. Now we
    # can compute the weighted average log loss per token.
    logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
    weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
    print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")

#%%
# First, create a sampling client; this transfers the trained weights to it.
sampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model')

# Now, we can sample from the model.
prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
params = types.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"]) # Greedy sampling
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8)
result = future.result()
print("Responses:")
for i, seq in enumerate(result.sequences):
    print(f"{i}: {repr(tokenizer.decode(seq.tokens))}")
# %%
48 changes: 48 additions & 0 deletions cookbook/tinker/server.py
@@ -0,0 +1,48 @@
import os
os.environ['RAY_DEBUG'] = '1'
import time

import ray
from omegaconf import OmegaConf
from ray import serve
from twinkle.server.tinker import build_model_app, build_server_app

ray.init(namespace="twinkle_cluster")
serve.shutdown()
time.sleep(5)  # give Serve time to tear down before redeploying

file_dir = os.path.abspath(os.path.dirname(__file__))
config = OmegaConf.load(os.path.join(file_dir, 'server_config.yaml'))

APP_BUILDERS = {
'main:build_server_app': build_server_app,
'main:build_model_app': build_model_app,
# 'main:build_sampler_app': build_sampler_app,
}

for app_config in config.applications:
    print(f"Starting {app_config.name} at {app_config.route_prefix}...")

    builder = APP_BUILDERS[app_config.import_path]
    args = OmegaConf.to_container(app_config.args, resolve=True) if app_config.args else {}

    deploy_options = {}
    deploy_config = app_config.deployments[0]
    if 'autoscaling_config' in deploy_config:
        deploy_options['autoscaling_config'] = OmegaConf.to_container(deploy_config.autoscaling_config)
    if 'ray_actor_options' in deploy_config:
        deploy_options['ray_actor_options'] = OmegaConf.to_container(deploy_config.ray_actor_options)

    app = builder(
        deploy_options=deploy_options,
        route_prefix=app_config.route_prefix,
        **args
    )

    serve.run(app, name=app_config.name, route_prefix=app_config.route_prefix)

print("\nAll applications started!")
print("Endpoints:")
for app_config in config.applications:
    print(f" - http://localhost:8000{app_config.route_prefix}")

input("\nPress Enter to stop the server...")
42 changes: 42 additions & 0 deletions cookbook/tinker/server_config.yaml
@@ -0,0 +1,42 @@
proxy_location: EveryNode
http_options:
  host: 0.0.0.0
  port: 8000

applications:
  - name: server
    route_prefix: /api/v1
    import_path: main:build_server_app
    args:

    deployments:
      - name: TinkerCompatServer
        autoscaling_config:
          min_replicas: 1
          max_replicas: 1
          target_ongoing_requests: 128
        ray_actor_options:
          num_cpus: 0.1

  - name: models
    route_prefix: /api/v1/model
    import_path: main:build_model_app
    args:
      nproc_per_node: 2
      device_group:
        name: model
        ranks: [0, 1]
        device_type: cuda
      device_mesh:
        device_type: cuda
        mesh: [0, 1]
        mesh_dim_names: ['dp']
    deployments:
      - name: ModelManagement
        autoscaling_config:
          min_replicas: 1
          max_replicas: 1
          target_ongoing_requests: 16
        ray_actor_options:
          num_cpus: 0.1
