diff --git a/README.md b/README.md index cc92fc2f..d0693836 100644 --- a/README.md +++ b/README.md @@ -135,8 +135,6 @@ supported on Twinkle✨ framework. | | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) | | deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 1.5B/7B/14B/32B | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | -For more detailed model support list 👉 [Quick Start](docs/source_en/Usage%20Guide/Quick-Start.md) - ## Sample Code Below are some of the capabilities demonstrated in the example code. For a complete introduction to training capabilities, diff --git a/README_ZH.md b/README_ZH.md index 10c1fc88..11a6cccc 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -116,8 +116,6 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl | | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) | | deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 1.5B/7B/14B/32B | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | -更详细的模型支持列表 👉 [快速开始.md](docs/source_zh/使用指引/快速开始.md) - ## 示例代码 下面列出了示例代码的一部分能力。完整的训练能力介绍请参考[快速开始](docs/source_zh/使用指引/快速开始.md)以及[cookbook](cookbook)。 diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py new file mode 100644 index 00000000..a93cd705 --- /dev/null +++ b/cookbook/mm/fsdp2.py @@ -0,0 +1,97 @@ +from peft import LoraConfig +from tqdm import tqdm + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.data_format import Trajectory, Message +from twinkle.dataloader import DataLoader +from twinkle.dataset import LazyDataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import Preprocessor + +# Construct a device_mesh, fsdp=2 +device_mesh = DeviceMesh.from_sizes(fsdp_size=2) +# use torchrun mode +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +logger = get_logger() + + +class LatexOCRProcessor(Preprocessor): + + def __call__(self, row) -> Trajectory: + return Trajectory( + messages=[ + Message(role='user', content='Using LaTeX to perform OCR on the image.', images=[row['image']]), + Message(role='assistant', content=row['text']), + ] + ) + + +def eval(model): + # 100 Samples + dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(100))) + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.map(LatexOCRProcessor) + dataset.encode() + dataloader = DataLoader(dataset=dataset, batch_size=8) + for step, batch in tqdm(enumerate(dataloader)): + model.forward_only(inputs=batch) + model.calculate_loss() + metrics = model.calculate_metric(is_training=False) + return metrics + + +def train(): + # 2000 samples + dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000))) + # Set template to prepare encoding + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B', max_length=1024) + # Preprocess the dataset to standard format + dataset.map(LatexOCRProcessor) + 
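# map() has now turned each row into a Trajectory: a user message carrying the image plus the assistant's LaTeX answer +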
# Encode dataset + dataset.encode() + # Global batch size = 4 across the 2 GPUs, so 2 samples per GPU + dataloader = DataLoader(dataset=dataset, batch_size=4) + # Use a TransformersModel + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration + model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', model_cls=Qwen3_5ForConditionalGeneration) + model.model._no_split_modules = {'Qwen3_5DecoderLayer'} + + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + + # Add a LoRA adapter named `default` to the model + # Comment this out to use full-parameter training + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) + # Set the template that builds model inputs + model.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') + # Add an optimizer for LoRA `default` + model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + # Add an LR scheduler for LoRA `default` + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + logger.info(get_device_placement()) + # Print the training config + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + loss_metric = 99.0 + for step, batch in enumerate(dataloader): + # Do forward and backward + model.forward_backward(inputs=batch) + # Clip gradients and step the optimizer + model.clip_grad_and_step() + if step % 20 == 0: + # Print the training metric + metric = model.calculate_metric(is_training=True) + logger.info(f'Step {step} of {len(dataloader)}, metric: {metric}') + if step > 0 and step % 40 == 0: + metrics = eval(model) + logger.info(f'Eval metric: {metrics}') + metrics['step'] = step + if loss_metric > float(metrics['loss']): + model.save(f'checkpoint-{step}') + loss_metric = float(metrics['loss']) + model.save('last-checkpoint') + + +if __name__ == '__main__': + train() diff --git a/cookbook/mm/fsdp2.sh b/cookbook/mm/fsdp2.sh new file mode 100644 index 00000000..46e9f27f --- /dev/null +++ b/cookbook/mm/fsdp2.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 fsdp2.py diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 5624495d..ca37d724 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -20,7 +20,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -35,7 +35,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -43,8 +43,7 @@ def train(): # Global batch size = 8, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=8) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') - model.model._no_split_modules = {'Qwen3_5DecoderLayer'} + model = TransformersModel(model_id='ms://Qwen/Qwen3-4B') lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') diff --git
a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 868b61c0..da6e2d28 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -10,7 +10,7 @@ from twinkle.preprocessor import SelfCognitionProcessor logger = get_logger() -MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +MODEL_ID = 'ms://Qwen/Qwen3-4B' DATASETS = 'ms://swift/self-cognition' device_group = [DeviceGroup( diff --git a/src/twinkle/hub/hub.py b/src/twinkle/hub/hub.py index 899de321..6e1653e1 100644 --- a/src/twinkle/hub/hub.py +++ b/src/twinkle/hub/hub.py @@ -165,7 +165,8 @@ def load_dataset(cls, subset_name: str, split: str, streaming: bool = False, - revision: Optional[str] = None): + revision: Optional[str] = None, + **kwargs): """Load a dataset from the repo Args: @@ -179,7 +180,7 @@ def load_dataset(cls, The Dataset instance """ hub = cls._get_hub_class(dataset_id) - return hub.load_dataset(cls.remove_source_type(dataset_id), subset_name, split, streaming, revision) + return hub.load_dataset(cls.remove_source_type(dataset_id), subset_name, split, streaming, revision, **kwargs) @classmethod def download_model(cls, diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index dc16f624..9780973b 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -370,7 +370,9 @@ def _patch_peft(_module): if isinstance(_module, PeftModel): _module.add_adapter(lora_tenant.adapter_name, config) else: - _module = get_peft_model(_module, config, lora_tenant.adapter_name) + _peft_model: PeftModel = get_peft_model(_module, config, lora_tenant.adapter_name) + _module.active_adapters = _peft_model.active_adapters + _module = _peft_model for name, submodule in _module.named_modules(): if isinstance(submodule, LoraLayer): diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 4386cc82..f4bccc53 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -57,11 +57,31 @@ def __init__( self.multi_adapter.save_initial_weights() # Active group for compatibility with single adapter self.active_group = None + self.handler = self.register_global_mm_forward_hook() def _check_adapter_valid(self, adapter_name: str): assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, ' f'current is: {adapter_name}') + def register_global_mm_forward_hook(self): + + def forward_hook(model, args, kwargs): + active_adapter = model.active_adapters[0] + active_adapter = self.multi_adapter.find_lora(active_adapter).tenant_adapter_name + optimizer_group = self.optimizer_group[active_adapter] + template = optimizer_group.template + assert template is not None + return template.pre_forward_hook(model, args, kwargs) + + model = self.strategy.unwrap_model(self.model) + return model.register_forward_pre_hook(forward_hook, with_kwargs=True) + + def register_mm_forward_hook(self, optimizer_group: OptimizerGroup): + pass + + def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup): + pass + def _lazy_wrap_model(self): pass diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 33f044af..e6fbc397 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -62,6 +62,7 @@ class OptimizerGroup: checkpoint_engine: CheckpointEngine = None _dp_group = 
None _device_mesh: DeviceMesh = None + _handler: Any = None def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: if gradient_accumulation_steps is None: @@ -284,11 +285,23 @@ def _lazy_wrap_model(self): assert optimizer is not None self.model, optimizer = self.strategy.wrap_model(self.model, optimizer) optimizer_group.optimizer = optimizer + self.register_mm_forward_hook(optimizer_group) else: # maybe forward_only, no optimizer_group available self.model = self.strategy.wrap_model(self.model) self._model_wrapped = True + def register_mm_forward_hook(self, optimizer_group: OptimizerGroup): + model = self.strategy.unwrap_model(self.model) + template = optimizer_group.template + assert template is not None + optimizer_group._handler = model.register_forward_pre_hook(template.pre_forward_hook, with_kwargs=True) + + def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup): + if optimizer_group._handler is not None: + optimizer_group._handler.remove() + optimizer_group._handler = None + @staticmethod def _should_enable_expert_parallel(expert_parallel_config: Optional[Dict[str, Any]], device_mesh: Optional[DeviceMesh]) -> bool: diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index ca021ff4..e7bce2d3 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -94,6 +94,10 @@ def to_tensor(_input): value = torch.from_numpy(value) elif isinstance(value, list) and isinstance(value[0], (int, float, np.number)): value = torch.tensor(value) + elif key in self.VLM_CONCAT_FIELDS: + if not isinstance(value[0], torch.Tensor): + value = [torch.tensor(v) for v in value] + value = torch.cat(value, dim=0) if isinstance(value, torch.Tensor): value = value.to(Platform.get_local_device()) if value.dim() == 1: @@ -260,7 +264,8 @@ def _create_4d_attention_mask(attention_mask): @staticmethod def _get_packed_seq_params(position_ids): - assert position_ids.shape[0] == 1 + if position_ids.shape[0] > 1: + position_ids = position_ids[:1] position_ids_f = position_ids.flatten() indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32) @@ -305,7 +310,10 @@ def to_transformers_dict(inputs: List[InputFeature], **kwargs) -> List[InputFeat results = [] for _input in inputs: output = {} - _keys = ['input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask'] + _keys = [ + 'input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask', + 'pixel_values', 'image_grid_thw' + ] for key in list(_input.keys()): if key in _keys: output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key] @@ -361,6 +369,9 @@ def _collate_macro_batch(self, inputs: List[InputFeature]) -> InputFeature: for field, values in vlm_fields.items(): if values: + if values[0].dim() == 1: + # image_thw may be squeezed + values = [value.unsqueeze(0) for value in values] result[field] = torch.cat(values, dim=0) return result diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 346ca37e..324ce7ac 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .base import Template -from .qwen3_vl import Qwen3VLTemplate +from .qwen3_5_vl import Qwen3_5Template diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 488d53a8..dfa04a58 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -1,12 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import inspect import numpy as np import os from collections.abc import Mapping -from copy import deepcopy +from copy import copy, deepcopy from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union from twinkle.data_format import InputFeature, Message, Trajectory from twinkle.hub import HubOperation +from twinkle.utils import load_image, to_device from .utils import tokenize_with_assistant_labels, transfer_to_standard_message if TYPE_CHECKING: @@ -40,6 +42,8 @@ def __init__(self, else: from transformers import AutoTokenizer self.processor = AutoTokenizer.from_pretrained(model_id, **kwargs) + from transformers import AutoConfig + self.config = AutoConfig.from_pretrained(model_id, **kwargs) self.use_chat_template = use_chat_template self.max_length = max_length @@ -109,7 +113,12 @@ def _test_support_assistant_tokens_mask(self): self._template_support_assistant_tokens_mask = False def preprocess_image(self, image: ImageInput) -> 'Image.Image': - return image + if isinstance(image, dict): + if image.get('path'): + image = image['path'] + else: + image = image['bytes'] + return load_image(image) def preprocess_video(self, video: VideoInput) -> List['Image.Image']: return video @@ -205,60 +214,26 @@ def _roll_labels(self, input_feature: InputFeature) -> List[InputFeature]: return [input_feature] def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: - # TODO code untested messages = trajectory['messages'] - # Get images/videos from trajectory level (common case) or message level - traj_images = trajectory.get('images') or [] - traj_videos = trajectory.get('videos') or [] - - # Preprocess all trajectory-level images and videos - if traj_images and self.is_mm: - traj_images = self.preprocess_images(traj_images) - if traj_videos and self.is_mm: - traj_videos = self.preprocess_videos(traj_videos) - - # Distribute trajectory-level images to messages that contain placeholders - image_idx = 0 - video_idx = 0 new_messages = [] for message in messages: - # If message already has images/videos at message level, use those + message = copy(message) + content = message['content'] msg_images = message.get('images') msg_videos = message.get('videos') - - # If not, assign from trajectory level based on placeholder count - if msg_images is None and self.is_mm: - content = message.get('content', '') - if isinstance(content, str): - placeholder_count = content.count(self.image_placeholder) - if placeholder_count > 0 and image_idx < len(traj_images): - msg_images = traj_images[image_idx:image_idx + placeholder_count] - image_idx += placeholder_count - elif msg_images and self.is_mm: - # Preprocess message-level images - msg_images = self.preprocess_images(msg_images) - - if msg_videos is None and self.is_mm: - content = message.get('content', '') - if isinstance(content, str): - placeholder_count = content.count(self.video_placeholder) - if placeholder_count > 0 and video_idx < len(traj_videos): - msg_videos = traj_videos[video_idx:video_idx + placeholder_count] - video_idx += placeholder_count - elif msg_videos and self.is_mm: - # Preprocess message-level videos - msg_videos = self.preprocess_videos(msg_videos) 
- - # Create message with images/videos attached - msg_with_media = dict(message) + msg_audios = message.get('audios') if msg_images: - msg_with_media['images'] = msg_images + message['images'] = self.preprocess_images(msg_images) + assert len(message['images']) == content.count(self.image_placeholder) if msg_videos: - msg_with_media['videos'] = msg_videos - + message['videos'] = self.preprocess_videos(msg_videos) + assert len(message['videos']) == content.count(self.video_placeholder) + if msg_audios: + message['audios'] = self.preprocess_audios(msg_audios) + assert len(message['audios']) == content.count(self.audio_placeholder) new_messages.append( - transfer_to_standard_message(msg_with_media, self.image_placeholder, self.video_placeholder, - self.is_mm)) + transfer_to_standard_message(message, self.image_placeholder, self.video_placeholder, + self.audio_placeholder, self.is_mm)) trajectory['messages'] = new_messages return [trajectory] @@ -377,65 +352,36 @@ def decode(self, token_ids: List[int], **kwargs) -> str: def batch_decode(self, token_ids: List[List[int]], **kwargs) -> List[str]: return [self.processor.decode(_ids, **kwargs) for _ids in token_ids] - def post_encode(self, model: 'torch.nn.Module', inputs: Dict[str, Any]) -> Dict[str, Any]: - """ - Transform inputs for model forward. - - Default: use helper methods for embedding merge. - Override if model handles internally (like Qwen3-VL). - """ - input_ids = inputs.get('input_ids') - if input_ids is None: - return inputs - - text_embeds = self._get_text_embeddings(model, input_ids) - vision_embeds = self._get_vision_embeddings(model, inputs) - - if vision_embeds is not None: - inputs_embeds = self._merge_vision_embeddings(text_embeds, vision_embeds, input_ids, inputs) - else: - inputs_embeds = text_embeds - - result = {k: v for k, v in inputs.items() if k != 'input_ids'} - result['inputs_embeds'] = inputs_embeds - return result - - def _get_text_embeddings(self, model: 'torch.nn.Module', input_ids: 'torch.Tensor') -> 'torch.Tensor': - """Get text embeddings from model.""" - embed_fn = None - if hasattr(model, 'get_input_embeddings'): - embed_fn = model.get_input_embeddings() - elif hasattr(model, 'model') and hasattr(model.model, 'embed_tokens'): - embed_fn = model.model.embed_tokens - elif hasattr(model, 'language_model') and hasattr(model.language_model, 'embed_tokens'): - embed_fn = model.language_model.embed_tokens - - if embed_fn is None: - raise ValueError('Cannot find embedding layer in model') - - return embed_fn(input_ids) - - def _get_vision_embeddings(self, model: 'torch.nn.Module', inputs: Dict[str, Any]) -> Optional['torch.Tensor']: - """Get vision embeddings. Override in subclass.""" - return None - def _get_vision_token_id(self) -> Optional[int]: - """Get vision placeholder token ID.
Override in subclass.""" - return self.processor.encode(self.image_placeholder) - - def _merge_vision_embeddings(self, text_embeds: 'torch.Tensor', vision_embeds: 'torch.Tensor', - input_ids: 'torch.Tensor', inputs: Dict[str, Any]) -> 'torch.Tensor': - """Merge vision embeddings at placeholder positions.""" - vision_token_id = self._get_vision_token_id() - if vision_token_id is None: - return text_embeds - - vision_mask = (input_ids == vision_token_id).unsqueeze(-1).expand_as(text_embeds) - vision_embeds = vision_embeds.to(device=text_embeds.device, dtype=text_embeds.dtype) - vision_mask = vision_mask.to(device=text_embeds.device) + if self.config is not None: + return getattr(self.config, 'image_token_id', None) + else: + return self.processor.encode(self.image_placeholder) - return text_embeds.masked_scatter(vision_mask, vision_embeds) + def _post_encode(self, model: 'torch.nn.Module', inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs - def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional['torch.Tensor']: - """Get position_ids. Override for models with special position encoding.""" - return None + def pre_forward_hook(self, model: 'torch.nn.Module', args, kwargs): + if not self.is_mm: + return args, kwargs + device = next(model.parameters()).device + old_kwargs = to_device(kwargs, device) + kwargs = to_device(self._post_encode(model, old_kwargs), device) + for k, v in old_kwargs.items(): + if k in { + 'input_ids', 'attention_mask', 'labels', 'position_ids', 'output_hidden_states', 'logits_to_keep', + 'max_length_q', 'max_length_k', 'cu_seq_lens_q', 'cu_seq_lens_k' + } and k not in kwargs: + kwargs[k] = v + if 'inputs_embeds' in kwargs: + kwargs.pop('input_ids', None) + + from peft import PeftModel + if isinstance(model, PeftModel): + base_model = model.model + else: + base_model = model + parameters = inspect.signature(base_model.forward).parameters + if 'position_ids' not in parameters: + kwargs.pop('position_ids', None) + return args, kwargs diff --git a/src/twinkle/template/qwen3_5_vl.py b/src/twinkle/template/qwen3_5_vl.py new file mode 100644 index 00000000..b95877a6 --- /dev/null +++ b/src/twinkle/template/qwen3_5_vl.py @@ -0,0 +1,84 @@ +import torch +from PIL import Image +from typing import Any, Dict, List, Optional, Union + +from twinkle import remote_class, requires +from twinkle.template import Template +from twinkle.template.base import ImageInput, VideoInput +from twinkle.template.utils import get_inputs_embeds_hf + + +@remote_class() +class Qwen3_5Template(Template): + """ + Template for the Qwen3.5 VL series. + + Unlike the base Template, _post_encode here does not pass inputs through + unchanged: it merges the vision features into the text embeddings,
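+ using get_inputs_embeds_hf, and the base class's pre_forward_hook then + feeds the merged inputs_embeds to the model at forward time.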
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._patch_size: Optional[int] = None + self._merge_size: Optional[int] = None + self._init_vision_config() + + def _init_vision_config(self): + """Initialize vision config from processor.""" + if hasattr(self.processor, 'image_processor'): + ip = self.processor.image_processor + self._patch_size = getattr(ip, 'patch_size', 16) + self._merge_size = getattr(ip, 'merge_size', 2) + + @property + def patch_size(self) -> int: + """Vision transformer patch size.""" + return self._patch_size or 16 + + @property + def merge_size(self) -> int: + """Spatial merge size for vision tokens.""" + return self._merge_size or 2 + + def preprocess_image(self, image: ImageInput) -> Image.Image: + requires('qwen_vl_utils') + from qwen_vl_utils.vision_process import fetch_image + image = super().preprocess_image(image) + if isinstance(image, str): + image_input = {'image': image} + elif isinstance(image, Image.Image): + image_input = {'image': image} + else: + # Fallback to base class for tensor inputs + return super().preprocess_image(image) + + # Use qwen_vl_utils with correct patch_size + return fetch_image(image_input, image_patch_size=self.patch_size) + + def preprocess_video(self, video: VideoInput) -> Union[List[Image.Image], torch.Tensor]: + requires('qwen_vl_utils') + from qwen_vl_utils.vision_process import fetch_video + + if isinstance(video, str): + video_input = {'video': video} + result = fetch_video(video_input, image_patch_size=self.patch_size, return_video_sample_fps=False) + return result + elif isinstance(video, list): + return [self.preprocess_image(frame) for frame in video] + else: + return super().preprocess_video(video) + + def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: + input_ids = inputs['input_ids'] + from peft import PeftModel + if isinstance(model, PeftModel): + base_model = model.model + else: + base_model = model + if hasattr(base_model.model, 'embed_tokens'): + inputs_embeds = base_model.model.embed_tokens(input_ids) + else: + inputs_embeds = base_model.model.language_model.embed_tokens(input_ids) + inputs_embeds = get_inputs_embeds_hf(inputs_embeds, inputs, base_model.model.visual, self.processor, + model.config) + return {'inputs_embeds': inputs_embeds} diff --git a/src/twinkle/template/qwen3_vl.py b/src/twinkle/template/qwen3_vl.py deleted file mode 100644 index 325d028a..00000000 --- a/src/twinkle/template/qwen3_vl.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -from PIL import Image -from typing import Any, Dict, List, Optional, Union - -from twinkle import remote_class -from twinkle.template import Template -from twinkle.template.base import ImageInput, VideoInput - - -@remote_class() -class Qwen3VLTemplate(Template): - """ - Processor for Qwen VL series. - - Note: Qwen3-VL handles embedding merge internally in forward(), - so post_encode just passes through inputs unchanged. 
- """ - - def __init__(self, *args, **kwargs): - # TODO untested code - super().__init__(*args, **kwargs) - # Cache processor config for preprocessing - self._patch_size: Optional[int] = None - self._merge_size: Optional[int] = None - self._init_vision_config() - - def _init_vision_config(self): - """Initialize vision config from processor.""" - if hasattr(self.processor, 'image_processor'): - ip = self.processor.image_processor - self._patch_size = getattr(ip, 'patch_size', 16) - self._merge_size = getattr(ip, 'merge_size', 2) - - @property - def patch_size(self) -> int: - """Vision transformer patch size.""" - return self._patch_size or 16 - - @property - def merge_size(self) -> int: - """Spatial merge size for vision tokens.""" - return self._merge_size or 2 - - def preprocess_image(self, image: ImageInput) -> Image.Image: - try: - from qwen_vl_utils.vision_process import fetch_image - if isinstance(image, str): - image_input = {'image': image} - elif isinstance(image, Image.Image): - image_input = {'image': image} - else: - # Fallback to base class for tensor inputs - return super().preprocess_image(image) - - # Use qwen_vl_utils with correct patch_size - return fetch_image(image_input, image_patch_size=self.patch_size) - - except ImportError: - return super().preprocess_image(image) - - def preprocess_video(self, video: VideoInput) -> Union[List[Image.Image], torch.Tensor]: - try: - from qwen_vl_utils.vision_process import fetch_video - - if isinstance(video, str): - # Use qwen_vl_utils for video loading - video_input = {'video': video} - result = fetch_video(video_input, image_patch_size=self.patch_size, return_video_sample_fps=False) - return result - elif isinstance(video, list): - # List of images - preprocess each frame - return [self.preprocess_image(frame) for frame in video] - else: - return super().preprocess_video(video) - - except ImportError: - return super().preprocess_video(video) - - # _build_messages: Uses base class implementation. - # Qwen's HF processor accepts the standard format: - # [{'role': 'user', 'content': [{'type': 'image'}, {'type': 'text', 'text': '...'}]}] - - def post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: - """Qwen3-VL handles embedding merge internally.""" - return inputs - - def _get_vision_token_id(self) -> Optional[int]: - if self.config is not None: - return getattr(self.config, 'image_token_id', None) - return None - - def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional[torch.Tensor]: - """Get 3D RoPE position_ids for Qwen VL.""" - if self.model is None: - return None - - input_ids = inputs.get('input_ids') - if input_ids is None: - return None - - # Find get_rope_index - base_model = self.model - if hasattr(base_model, 'base_model'): - base_model = base_model.base_model - if hasattr(base_model, 'model'): - base_model = base_model.model - - get_rope_index = getattr(base_model, 'get_rope_index', None) - if get_rope_index is None and hasattr(base_model, 'model'): - get_rope_index = getattr(base_model.model, 'get_rope_index', None) - - if get_rope_index is None: - return None - - try: - position_ids, _ = get_rope_index(input_ids, inputs.get('image_grid_thw'), inputs.get('video_grid_thw'), - inputs.get('attention_mask')) - return position_ids - except Exception: - return None diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index 2bea8f22..ccdcfc91 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -1,13 +1,16 @@ # Copyright (c) ModelScope Contributors. 
All rights reserved. import inspect from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar from twinkle.data_format import Message, Trajectory +from twinkle.utils import to_device if TYPE_CHECKING: from transformers import PreTrainedTokenizer +_T = TypeVar('_T') + PLACEHOLDER = '<<>>' @@ -101,14 +104,14 @@ def tokenize_with_assistant_labels( assistant_count += 1 _dummy_messages.append(msg) - encoded = encode_func(trajectory, ) + encoded = encode_func(trajectory) full_ids = encoded.pop('input_ids') if isinstance(full_ids, torch.Tensor): full_ids = full_ids.tolist()[0] _dummy_trajectory = copy(trajectory) _dummy_trajectory['messages'] = _dummy_messages - template_ids = encode_func(_dummy_trajectory, ) + template_ids = encode_func(_dummy_trajectory) template_ids = template_ids['input_ids'] if isinstance(template_ids, torch.Tensor): template_ids = template_ids.tolist()[0] @@ -168,50 +171,66 @@ def _load_image(img: Any) -> Optional[Any]: return img -def _transfer_single_message(content: str, image_placeholder, video_placeholder, images, videos): - image_idx = 0 - video_idx = 0 - remaining = content - # Handle None images/videos - images = images or [] - videos = videos or [] - has_image = image_placeholder in content - has_video = video_placeholder in content - new_content = [] - while remaining: - img_pos = remaining.find(image_placeholder) if has_image else -1 - vid_pos = remaining.find(video_placeholder) if has_video else -1 - - # Find next placeholder - if img_pos == -1 and vid_pos == -1: - if remaining.strip(): - new_content.append({'type': 'text', 'text': remaining}) - break +def _transfer_single_message( + content: str, + image_placeholder: str, + video_placeholder: str, + audio_placeholder: str, + images: list | None = None, + videos: list | None = None, + audios: list | None = None, +) -> list[dict]: + if not content: + return [] + + media_configs = [ + (image_placeholder, 'image', images or []), + (video_placeholder, 'video', videos or []), + (audio_placeholder, 'audio', audios or []), + ] + + placeholders = [] + for placeholder, media_type, media_list in media_configs: + if not placeholder: + continue + start = 0 + media_idx = 0 + while (pos := content.find(placeholder, start)) != -1: + url = media_list[media_idx] if media_idx < len(media_list) else None + placeholders.append((pos, len(placeholder), media_type, url)) + media_idx += 1 + start = pos + len(placeholder) - # Determine which comes first - if vid_pos == -1 or (img_pos != -1 and img_pos < vid_pos): - # Image placeholder - if remaining[:img_pos].strip(): - new_content.append({'type': 'text', 'text': remaining[:img_pos]}) - if image_idx < len(images): - new_content.append({'type': 'image', 'url': images[image_idx]}) - image_idx += 1 - remaining = remaining[img_pos + len(image_placeholder):] - else: - # Video placeholder - if remaining[:vid_pos].strip(): - new_content.append({'type': 'text', 'text': remaining[:vid_pos]}) - if video_idx < len(videos): - new_content.append({'type': 'video', 'url': videos[video_idx]}) - video_idx += 1 - remaining = remaining[vid_pos + len(video_placeholder):] - return new_content + if not placeholders: + return [{'type': 'text', 'text': content}] if content.strip() else [] + + placeholders.sort(key=lambda x: x[0]) + + result = [] + cursor = 0 + + for pos, length, media_type, url in placeholders: + text_segment = content[cursor:pos] + if text_segment.strip(): 
+ result.append({'type': 'text', 'text': text_segment}) + + if url is not None: + result.append({'type': media_type, 'url': url}) + cursor = pos + length -def transfer_to_standard_message(message: Message, image_placeholder, video_placeholder, is_mm): + trailing_text = content[cursor:] + if trailing_text.strip(): + result.append({'type': 'text', 'text': trailing_text}) + + return result + + +def transfer_to_standard_message(message: Message, image_placeholder, video_placeholder, audio_placeholder, is_mm): if is_mm: new_content = _transfer_single_message(message['content'], image_placeholder, video_placeholder, - message.get('images'), message.get('videos')) + audio_placeholder, message.get('images'), message.get('videos'), + message.get('audios')) else: new_content = message['content'] @@ -220,3 +239,61 @@ def transfer_to_standard_message(message: Message, image_placeholder, video_plac content=new_content, tool_calls=message.get('tool_calls'), reasoning_content=message.get('reasoning_content')) + + +def get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config): + input_ids = inputs['input_ids'] + pixel_values = inputs.get('pixel_values') + pixel_values_videos = inputs.get('pixel_values_videos') + image_grid_thw = inputs.get('image_grid_thw') + video_grid_thw = inputs.get('video_grid_thw') + dtype = visual.dtype + if pixel_values is None and pixel_values_videos is None: + from PIL import Image + images = [Image.new('RGB', (32, 32), (0, 0, 0))] + media_inputs = processor.image_processor(images=images, return_tensors='pt') + media_inputs = to_device(media_inputs, input_ids.device) + pixel_values = media_inputs['pixel_values'].type(dtype) + image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) + if hasattr(image_embeds, 'pooler_output'): + image_embeds = image_embeds.pooler_output + inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0. 
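+ # The zero-scaled mean keeps the vision tower in the autograd graph on text-only batches + # (so that e.g. FSDP/DDP still see its parameters participate) without changing inputs_embeds.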
+ else: + import torch + if pixel_values is None: + pixel_values_mixed = pixel_values_videos + grid_thw = video_grid_thw + elif pixel_values_videos is None: + pixel_values_mixed = pixel_values + grid_thw = image_grid_thw + else: + pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0) + grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0) + pixel_values_mixed = pixel_values_mixed.type(dtype) + mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw) + if hasattr(mixed_embeds, 'pooler_output'): + mixed_embeds = mixed_embeds.pooler_output + if pixel_values is None: + image_embeds = None + video_embeds = mixed_embeds + elif pixel_values_videos is None: + image_embeds = mixed_embeds + video_embeds = None + else: + merge_length = processor.image_processor.merge_size**2 + image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum() + image_embeds = mixed_embeds[:image_tokens] + video_embeds = mixed_embeds[image_tokens:] + + if image_embeds is not None: + image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask = image_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if video_embeds is not None: + video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_mask = video_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + return inputs_embeds diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 1d7f9028..1b018773 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -14,3 +14,4 @@ from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr +from .vision_tools import load_image, load_mm_file diff --git a/src/twinkle/utils/vision_tools.py b/src/twinkle/utils/vision_tools.py new file mode 100644 index 00000000..5638fca0 --- /dev/null +++ b/src/twinkle/utils/vision_tools.py @@ -0,0 +1,54 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
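+"""Helpers for loading multimodal inputs (local paths, http(s) URLs, raw bytes, and base64 data URIs) as BytesIO buffers or RGB PIL images."""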
+import base64 +import os +import re +import requests +from io import BytesIO +from requests.adapters import HTTPAdapter +from typing import TYPE_CHECKING, TypeVar, Union +from urllib3.util.retry import Retry + +if TYPE_CHECKING: + from PIL import Image + +_T = TypeVar('_T') + + +def load_mm_file(path: Union[str, bytes, _T]) -> Union[BytesIO, _T]: + res = path + if isinstance(path, str): + path = path.strip() + if path.startswith('http'): + retries = Retry(total=3, backoff_factor=1, allowed_methods=['GET']) + with requests.Session() as session: + session.mount('http://', HTTPAdapter(max_retries=retries)) + session.mount('https://', HTTPAdapter(max_retries=retries)) + response = session.get(path, timeout=10) + response.raise_for_status() + content = response.content + res = BytesIO(content) + else: + data = path + if os.path.exists(path): + with open(path, 'rb') as f: + res = BytesIO(f.read()) + else: + if data.startswith('data:'): + match_ = re.match(r'data:(.+?);base64,(.+)', data) + assert match_ is not None + data = match_.group(2) + data = base64.b64decode(data) + res = BytesIO(data) + elif isinstance(path, bytes): + res = BytesIO(path) + return res + + +def load_image(image: Union[str, bytes, 'Image.Image']) -> 'Image.Image': + image = load_mm_file(image) + if isinstance(image, BytesIO): + from PIL import Image + image = Image.open(image) + if image.mode != 'RGB': + image = image.convert('RGB') + return image diff --git a/tests/dataloader/test_multimodal.py b/tests/dataloader/test_multimodal.py index 9b4905bb..0031b150 100644 --- a/tests/dataloader/test_multimodal.py +++ b/tests/dataloader/test_multimodal.py @@ -28,9 +28,9 @@ def test_dataloader_multimodal_with_lazy_dataset(self): dataset.map(create_multimodal_messages) try: - dataset.set_template('Qwen3VLTemplate', model_id='Qwen/Qwen2-VL-7B-Instruct') + dataset.set_template('Qwen3_5Template', model_id='Qwen/Qwen2-VL-7B-Instruct') except Exception as e: - pytest.skip(f'Failed to load Qwen3VLTemplate (may need network): {e}') + pytest.skip(f'Failed to load Qwen3_5Template (may need network): {e}') dataset.encode() diff --git a/tests/dataset/test_multimodal.py b/tests/dataset/test_multimodal.py index 31e92239..5fca8f4e 100644 --- a/tests/dataset/test_multimodal.py +++ b/tests/dataset/test_multimodal.py @@ -37,15 +37,15 @@ def test_multimodal_dataset_basic(self): @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download') def test_multimodal_dataset_with_qwen3vl_template(self): - # Use Qwen3VLTemplate + # Use Qwen3_5Template csv_path = str(TEST_DATA_DIR / 'test.csv') dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path)) dataset.map(create_multimodal_messages) try: - dataset.set_template('Qwen3VLTemplate', model_id='Qwen/Qwen3-VL-2B-Instruct') + dataset.set_template('Qwen3_5Template', model_id='Qwen/Qwen3-VL-2B-Instruct') except Exception as e: - pytest.skip(f'Failed to load Qwen3VLTemplate (may need network): {e}') + pytest.skip(f'Failed to load Qwen3_5Template (may need network): {e}') assert dataset.template is not None assert hasattr(dataset.template, 'is_mm') @@ -58,9 +58,9 @@ def test_multimodal_dataset_encode_with_lazy(self): dataset.map(create_multimodal_messages) try: - dataset.set_template('Qwen3VLTemplate', model_id='Qwen/Qwen3-VL-2B-Instruct') + dataset.set_template('Qwen3_5Template', model_id='Qwen/Qwen3-VL-2B-Instruct') except Exception as e: - pytest.skip(f'Failed to load Qwen3VLTemplate (may need network): {e}') + pytest.skip(f'Failed to load Qwen3_5Template (may need network): 
{e}') try: dataset.encode() diff --git a/tests/sampler/align_swift.py b/tests/sampler/align_swift.py index b9bcec77..dc33ff36 100644 --- a/tests/sampler/align_swift.py +++ b/tests/sampler/align_swift.py @@ -17,7 +17,6 @@ import gc import os -import sys import torch from swift.infer_engine import RequestConfig, TransformersEngine, VllmEngine from swift.utils import seed_everything @@ -27,8 +26,7 @@ from twinkle.data_format import SamplingParams, Trajectory from twinkle.sampler.torch_sampler import TorchSampler from twinkle.sampler.vllm_sampler import vLLMSampler -from twinkle.template import Template -from twinkle.template.qwen3_vl import Qwen3VLTemplate +from twinkle.template import Qwen3_5Template, Template # Test models LLM_MODEL_ID = 'Qwen/Qwen2.5-7B-Instruct' @@ -182,7 +180,6 @@ def test_llm_vllm_sampler_ray(): from peft import LoraConfig from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger - from twinkle.checkpoint_engine import CheckpointEngineManager from twinkle.model import TransformersModel from twinkle.processor import InputProcessor @@ -269,7 +266,7 @@ def test_mllm_torch_sampler(): seed_everything(42) from transformers import Qwen3VLForConditionalGeneration sampler = TorchSampler(MLLM_MODEL_ID, model_cls=Qwen3VLForConditionalGeneration) - sampler.set_template(Qwen3VLTemplate, model_id=MLLM_MODEL_ID) + sampler.set_template(Qwen3_5Template, model_id=MLLM_MODEL_ID) trajectory = Trajectory(messages=MLLM_MESSAGES, images=MLLM_IMAGES) sampling_params = SamplingParams(max_tokens=128, temperature=0) @@ -297,7 +294,7 @@ def test_mllm_vllm_sampler(): seed_everything(42) sampler = vLLMSampler(MLLM_MODEL_ID, gpu_memory_utilization=VLLM_GPU_MEM, max_model_len=VLLM_MAX_MODEL_LEN) - sampler.set_template(Qwen3VLTemplate, model_id=MLLM_MODEL_ID) + sampler.set_template(Qwen3_5Template, model_id=MLLM_MODEL_ID) trajectory = Trajectory(messages=MLLM_MESSAGES, images=MLLM_IMAGES) sampling_params = SamplingParams(max_tokens=128, temperature=0)