|
# Tinker-Compatible Client - DPO (Direct Preference Optimization) Training with LoRA
#
# This script demonstrates how to fine-tune a language model using DPO
# through the Tinker-compatible client API.
#
# Training flow per step:
#   1. forward_backward with 'cross_entropy' + disable_lora=True
#      → base-model forward pass; LoRA weights are NOT in the computation graph
#        so backward accumulates zero LoRA gradients (safe to discard).
#   2. Attach returned per-token ref logps to each datum's loss_fn_inputs.
#   3. forward_backward with 'importance_sampling'
#      → server detects ref_logps and switches to DPOLoss + DPOMetric.
#   4. optim_step → update LoRA, DPO metrics returned automatically.
#
# The server must be running first (see server.py and server_config.yaml).
| 16 | + |
import numpy as np
import torch
from tqdm import tqdm
from typing import Any, Dict, List

from tinker import types
from twinkle import init_tinker_client, get_logger
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.dataloader import DataLoader
from twinkle.preprocessor import EmojiDPOProcessor
from twinkle.server.common import input_feature_to_datum

logger = get_logger()

# Initialize the Tinker client before importing ServiceClient.
# NOTE: this import-order dependency is intentional — ServiceClient must be
# imported only after init_tinker_client() has run (hence the E402 suppression).
init_tinker_client()

from tinker import ServiceClient  # noqa: E402 (must follow init_tinker_client)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
base_model = 'Qwen/Qwen3.5-4B'   # base model the LoRA adapter is trained on
base_url = 'http://localhost:8000'   # Tinker-compatible server endpoint
api_key = 'EMPTY_API_KEY'   # placeholder key for a local server
dataset_id = 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'   # ModelScope preference dataset

batch_size = 4          # rows per step (each row expands to a pos/neg pair)
learning_rate = 1e-4    # Adam learning rate for optim_step
dpo_beta = 0.1          # forwarded to server as loss_fn_config['dpo_beta']
sft_weight = 1.0        # forwarded as 'dpo_sft_weight' (auxiliary SFT term weight)
loss_type = 'sigmoid'   # forwarded as 'dpo_loss_type'
max_length = 2048       # max token length used by the dataset template
lora_rank = 8           # LoRA adapter rank passed to create_lora_training_client
system_prompt = 'You are a helpful assistant.'   # system message injected by EmojiDPOProcessor
| 52 | + |
| 53 | + |
| 54 | +# --------------------------------------------------------------------------- |
| 55 | +# Dataset helpers (reused from twinkle/self_host/dpo.py) |
| 56 | +# --------------------------------------------------------------------------- |
| 57 | + |
def create_dpo_dataset():
    """Build and encode the emoji-preference dataset in chosen/rejected form.

    Loads the first 600 rows of the ModelScope dataset, applies the Qwen3.5
    chat template, and maps each row through EmojiDPOProcessor, which yields
    {'positive': InputFeature, 'negative': InputFeature, ...} pairs that
    encode() handles automatically.
    """
    meta = DatasetMeta(dataset_id, data_slice=range(600))
    ds = Dataset(meta)
    ds.set_template('Qwen3_5Template', model_id=f'ms://{base_model}', max_length=max_length)
    ds.map(EmojiDPOProcessor, init_args={'system': system_prompt})
    ds.encode()
    return ds
| 70 | + |
| 71 | + |
def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Interleave chosen/rejected samples as [pos_1, neg_1, pos_2, neg_2, ...].

    Each input row carries a 'positive' and a 'negative' InputFeature plus any
    shared metadata fields. The shared fields are merged into both emitted
    samples (pair-specific keys win on collision), and each pair is emitted
    adjacently so that an even-sized contiguous slice handed to a DP worker
    always contains complete preference pairs.

    Args:
        batch: Rows, each containing 'positive' and 'negative' dicts.

    Returns:
        Flat interleaved list, twice the length of ``batch``.
    """
    interleaved: List[Dict[str, Any]] = []
    for row in batch:
        shared = {key: val for key, val in row.items() if key not in ('positive', 'negative')}
        interleaved.extend(({**shared, **row['positive']}, {**shared, **row['negative']}))
    return interleaved
| 89 | + |
| 90 | + |
| 91 | +# --------------------------------------------------------------------------- |
| 92 | +# Training |
| 93 | +# --------------------------------------------------------------------------- |
| 94 | + |
def train() -> None:
    """Run the two-pass DPO-with-LoRA training loop against a Tinker server.

    Per step: (A) a reference forward/backward with LoRA disabled to obtain
    base-model per-token logprobs, (B) attach those as ref_logps on each
    datum, (C) a second forward/backward where the server — detecting
    ref_logps — applies the DPO loss, and (D) an Adam optimizer step.
    Finishes by saving the LoRA state as a named checkpoint.

    Requires a running server at `base_url` (see module header).
    """
    # Step 1: Prepare dataset & dataloader
    logger.info('Loading DPO dataset...')
    dataset = create_dpo_dataset()
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
    logger.info(f'Dataset ready: {len(dataloader)} steps per epoch')

    # Step 2: Connect to server and create LoRA training client
    service_client = ServiceClient(base_url=base_url, api_key=api_key)
    training_client = service_client.create_lora_training_client(
        base_model=base_model,
        rank=lora_rank,
    )
    logger.info(f'LoRA training client created (rank={lora_rank})')
    logger.info(f'Starting DPO training: loss_type={loss_type}, beta={dpo_beta}, lr={learning_rate}')

    # Step 3: Training loop (single epoch over the dataloader)
    for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Normalise numpy / torch tensors to plain Python lists for serialisation.
        # NOTE(review): only top-level values are converted here; array values
        # nested inside row['positive'] / row['negative'] would be left as-is —
        # presumably the encoded InputFeatures are already list-based; confirm.
        for row in batch:
            for key in list(row.keys()):
                if isinstance(row[key], np.ndarray):
                    row[key] = row[key].tolist()
                elif isinstance(row[key], torch.Tensor):
                    row[key] = row[key].cpu().numpy().tolist()

        # Build interleaved [pos, neg, pos, neg, ...] batch so DP slices keep pairs intact
        dpo_batch = prepare_dpo_batch(batch)

        # Convert each InputFeature dict to a Tinker Datum
        input_datums = [input_feature_to_datum(row) for row in dpo_batch]

        # -----------------------------------------------------------------
        # A. Reference forward pass (base model, disable_lora=True)
        #    LoRA weights are outside the computation graph → backward
        #    produces zero LoRA gradients, so this call is safe.
        # -----------------------------------------------------------------
        ref_result = training_client.forward_backward(
            input_datums,
            'cross_entropy',
            loss_fn_config={'disable_lora': True},
        ).result()

        # -----------------------------------------------------------------
        # B. Attach per-token ref logps to each datum's loss_fn_inputs
        #    (outputs are zipped positionally — order matches input_datums)
        # -----------------------------------------------------------------
        for datum, ref_out in zip(input_datums, ref_result.loss_fn_outputs):
            ref_logprobs_np = np.array(ref_out['logprobs'].tolist(), dtype=np.float32)
            datum.loss_fn_inputs['ref_logps'] = types.TensorData.from_numpy(ref_logprobs_np)

        # -----------------------------------------------------------------
        # C. DPO forward_backward
        #    Server detects ref_logps → sets DPOLoss + DPOMetric automatically.
        #    Optional DPO hyper-params can be forwarded via loss_fn_config.
        # -----------------------------------------------------------------
        fwdbwd_result = training_client.forward_backward(
            input_datums,
            'importance_sampling',
            loss_fn_config={
                'dpo_beta': dpo_beta,
                'dpo_loss_type': loss_type,
                'dpo_sft_weight': sft_weight,
            },
        ).result()

        # -----------------------------------------------------------------
        # D. Optimizer step — DPOMetric is calculated automatically on the
        #    server and returned inside optim_result.metrics.
        # -----------------------------------------------------------------
        optim_result = training_client.optim_step(
            types.AdamParams(learning_rate=learning_rate)
        ).result()

        # 'loss:avg' may be absent depending on server metrics — fall back to 'N/A'
        dpo_loss = fwdbwd_result.metrics.get('loss:avg', 'N/A')
        logger.info(f'[Step {step}] dpo_loss={dpo_loss} | metrics={optim_result.metrics}')

    # Step 4: Save checkpoint (server-side LoRA state under this name)
    save_result = training_client.save_state('dpo-lora-final').result()
    logger.info(f'Saved checkpoint: {save_result.path}')

    # Step 5: (Optional) Upload to ModelScope Hub
    # YOUR_USER_NAME = 'your_username'
    # hub_model_id = f'{YOUR_USER_NAME}/twinkle-tinker-dpo-lora'
    # training_client.publish_checkpoint_from_tinker_path(save_result.path).result()
    # logger.info(f'Uploaded checkpoint to hub: {hub_model_id}')
| 181 | + |
# Script entry point: run the full training loop when executed directly.
if __name__ == '__main__':
    train()