From 8590c318c1b798e6fd5b515235caaa1929b1b93e Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 3 Mar 2026 21:54:08 +0800 Subject: [PATCH 01/13] support mm --- README.md | 2 - README_ZH.md | 2 - src/twinkle/model/multi_lora.py | 4 +- .../transformers/multi_lora_transformers.py | 20 +++ .../model/transformers/transformers.py | 13 ++ src/twinkle/template/base.py | 151 ++++++----------- src/twinkle/template/qwen3_vl.py | 27 +-- src/twinkle/template/utils.py | 154 +++++++++++++----- 8 files changed, 212 insertions(+), 161 deletions(-) diff --git a/README.md b/README.md index cc92fc2f..d0693836 100644 --- a/README.md +++ b/README.md @@ -135,8 +135,6 @@ supported on Twinkle✨ framework. | | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) | | deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 1.5B/7B/14B/32B | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | -For more detailed model support list 👉 [Quick Start](docs/source_en/Usage%20Guide/Quick-Start.md) - ## Sample Code Below are some of the capabilities demonstrated in the example code. 
For a complete introduction to training capabilities, diff --git a/README_ZH.md b/README_ZH.md index 10c1fc88..11a6cccc 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -116,8 +116,6 @@ Twinkle✨支持相同的算法接口运行在单GPU、torchrun多机、Ray、Cl | | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | - | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) | | deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | 1.5B/7B/14B/32B | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | -更详细的模型支持列表 👉 [快速开始.md](docs/source_zh/使用指引/快速开始.md) - ## 示例代码 下面列出了示例代码的一部分能力。完整的训练能力介绍请参考[快速开始](docs/source_zh/使用指引/快速开始.md)以及[cookbook](cookbook)。 diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index dc16f624..9780973b 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -370,7 +370,9 @@ def _patch_peft(_module): if isinstance(_module, PeftModel): _module.add_adapter(lora_tenant.adapter_name, config) else: - _module = get_peft_model(_module, config, lora_tenant.adapter_name) + _peft_model: PeftModel = get_peft_model(_module, config, lora_tenant.adapter_name) + _module.active_adapters = _peft_model.active_adapters + _module = _peft_model for name, submodule in _module.named_modules(): if isinstance(submodule, LoraLayer): diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 4386cc82..c638d06e 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import os from peft import LoraConfig, PeftConfig, PeftModel, load_peft_weights +import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel @@ -57,11 +58,30 @@ def __init__( self.multi_adapter.save_initial_weights() # Active group for compatibility with single adapter self.active_group = None + self.handler = self.register_global_mm_forward_hook() def _check_adapter_valid(self, adapter_name: str): assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, ' f'current is: {adapter_name}') + def register_global_mm_forward_hook(self): + + def forward_hook(model: torch.nn.Module, args, kwargs): + active_adapter = model.active_adapters[0] + optimizer_group = self.optimizer_group[active_adapter] + template = optimizer_group.template + assert template is not None + return template.pre_forward_hook(model, args, kwargs) + + model = self.strategy.unwrap_model(self.model) + return model.register_forward_pre_hook(forward_hook, with_kwargs=True) + + def register_mm_forward_hook(self, optimizer_group: OptimizerGroup): + pass + + def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup): + pass + def _lazy_wrap_model(self): pass diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 33f044af..e6fbc397 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -62,6 +62,7 @@ class OptimizerGroup: checkpoint_engine: CheckpointEngine = None _dp_group = None _device_mesh: DeviceMesh = None + _handler: Any = None def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool: if gradient_accumulation_steps is None: @@ -284,11 +285,23 @@ def _lazy_wrap_model(self): assert optimizer is not None self.model, optimizer = self.strategy.wrap_model(self.model, optimizer) 
optimizer_group.optimizer = optimizer + self.register_mm_forward_hook(optimizer_group) else: # maybe forward_only, no optimizer_group available self.model = self.strategy.wrap_model(self.model) self._model_wrapped = True + def register_mm_forward_hook(self, optimizer_group: OptimizerGroup): + model = self.strategy.unwrap_model(self.model) + template = optimizer_group.template + assert template is not None + optimizer_group._handler = model.register_forward_pre_hook(template.pre_forward_hook, with_kwargs=True) + + def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup): + if optimizer_group._handler is not None: + optimizer_group._handler.remove() + optimizer_group._handler = None + @staticmethod def _should_enable_expert_parallel(expert_parallel_config: Optional[Dict[str, Any]], device_mesh: Optional[DeviceMesh]) -> bool: diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 488d53a8..e73efd77 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -1,13 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import inspect + import numpy as np import os from collections.abc import Mapping -from copy import deepcopy +from copy import deepcopy, copy from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union from twinkle.data_format import InputFeature, Message, Trajectory from twinkle.hub import HubOperation from .utils import tokenize_with_assistant_labels, transfer_to_standard_message +from twinkle.utils import to_device if TYPE_CHECKING: import torch @@ -40,6 +43,8 @@ def __init__(self, else: from transformers import AutoTokenizer self.processor = AutoTokenizer.from_pretrained(model_id, **kwargs) + from transformers import AutoConfig + self.config = AutoConfig.from_pretrained(model_id, **kwargs) self.use_chat_template = use_chat_template self.max_length = max_length @@ -205,59 +210,26 @@ def _roll_labels(self, input_feature: InputFeature) -> List[InputFeature]: return [input_feature] def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: - # TODO code untested messages = trajectory['messages'] - # Get images/videos from trajectory level (common case) or message level - traj_images = trajectory.get('images') or [] - traj_videos = trajectory.get('videos') or [] - - # Preprocess all trajectory-level images and videos - if traj_images and self.is_mm: - traj_images = self.preprocess_images(traj_images) - if traj_videos and self.is_mm: - traj_videos = self.preprocess_videos(traj_videos) - - # Distribute trajectory-level images to messages that contain placeholders - image_idx = 0 - video_idx = 0 new_messages = [] for message in messages: - # If message already has images/videos at message level, use those + message = copy(message) + content = message['content'] msg_images = message.get('images') msg_videos = message.get('videos') - - # If not, assign from trajectory level based on placeholder count - if msg_images is None and self.is_mm: - content = message.get('content', '') - if isinstance(content, str): - placeholder_count = 
content.count(self.image_placeholder) - if placeholder_count > 0 and image_idx < len(traj_images): - msg_images = traj_images[image_idx:image_idx + placeholder_count] - image_idx += placeholder_count - elif msg_images and self.is_mm: - # Preprocess message-level images - msg_images = self.preprocess_images(msg_images) - - if msg_videos is None and self.is_mm: - content = message.get('content', '') - if isinstance(content, str): - placeholder_count = content.count(self.video_placeholder) - if placeholder_count > 0 and video_idx < len(traj_videos): - msg_videos = traj_videos[video_idx:video_idx + placeholder_count] - video_idx += placeholder_count - elif msg_videos and self.is_mm: - # Preprocess message-level videos - msg_videos = self.preprocess_videos(msg_videos) - - # Create message with images/videos attached - msg_with_media = dict(message) + msg_audios = message.get('audios') if msg_images: - msg_with_media['images'] = msg_images + message['images'] = self.preprocess_images(msg_images) + assert len(message['images']) == content.count(self.image_placeholder) if msg_videos: - msg_with_media['videos'] = msg_videos - + message['videos'] = self.preprocess_videos(msg_videos) + assert len(message['videos']) == content.count(self.video_placeholder) + if msg_audios: + message['audios'] = self.preprocess_audios(msg_audios) + assert len(message['audios']) == content.count(self.audio_placeholder) new_messages.append( - transfer_to_standard_message(msg_with_media, self.image_placeholder, self.video_placeholder, + transfer_to_standard_message(message, self.image_placeholder, self.video_placeholder, + self.audio_placeholder, self.is_mm)) trajectory['messages'] = new_messages @@ -377,65 +349,38 @@ def decode(self, token_ids: List[int], **kwargs) -> str: def batch_decode(self, token_ids: List[List[int]], **kwargs) -> List[str]: return [self.processor.decode(_ids, **kwargs) for _ids in token_ids] - def post_encode(self, model: 'torch.nn.Module', inputs: Dict[str, Any]) -> 
Dict[str, Any]: - """ - Transform inputs for model forward. - - Default: use helper methods for embedding merge. - Override if model handles internally (like Qwen3-VL). - """ - input_ids = inputs.get('input_ids') - if input_ids is None: - return inputs - - text_embeds = self._get_text_embeddings(model, input_ids) - vision_embeds = self._get_vision_embeddings(model, inputs) - - if vision_embeds is not None: - inputs_embeds = self._merge_vision_embeddings(text_embeds, vision_embeds, input_ids, inputs) - else: - inputs_embeds = text_embeds - - result = {k: v for k, v in inputs.items() if k != 'input_ids'} - result['inputs_embeds'] = inputs_embeds - return result - - def _get_text_embeddings(self, model: 'torch.nn.Module', input_ids: 'torch.Tensor') -> 'torch.Tensor': - """Get text embeddings from model.""" - embed_fn = None - if hasattr(model, 'get_input_embeddings'): - embed_fn = model.get_input_embeddings() - elif hasattr(model, 'model') and hasattr(model.model, 'embed_tokens'): - embed_fn = model.model.embed_tokens - elif hasattr(model, 'language_model') and hasattr(model.language_model, 'embed_tokens'): - embed_fn = model.language_model.embed_tokens - - if embed_fn is None: - raise ValueError('Cannot find embedding layer in model') - - return embed_fn(input_ids) - - def _get_vision_embeddings(self, model: 'torch.nn.Module', inputs: Dict[str, Any]) -> Optional['torch.Tensor']: - """Get vision embeddings. Override in subclass.""" - return None - def _get_vision_token_id(self) -> Optional[int]: - """Get vision placeholder token ID. 
Override in subclass.""" - return self.processor.encode(self.image_placeholder) - - def _merge_vision_embeddings(self, text_embeds: 'torch.Tensor', vision_embeds: 'torch.Tensor', - input_ids: 'torch.Tensor', inputs: Dict[str, Any]) -> 'torch.Tensor': - """Merge vision embeddings at placeholder positions.""" - vision_token_id = self._get_vision_token_id() - if vision_token_id is None: - return text_embeds - - vision_mask = (input_ids == vision_token_id).unsqueeze(-1).expand_as(text_embeds) - vision_embeds = vision_embeds.to(device=text_embeds.device, dtype=text_embeds.dtype) - vision_mask = vision_mask.to(device=text_embeds.device) - - return text_embeds.masked_scatter(vision_mask, vision_embeds) + if self.config is not None: + return getattr(self.config, 'image_token_id', None) + else: + return self.processor.encode(self.image_placeholder) def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional['torch.Tensor']: """Get position_ids. Override for models with special position encoding.""" return None + + def _post_encode(self, model: torch.nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + def pre_forward_hook(self, model: torch.nn.Module, args, kwargs): + device = next(model.parameters()).device + old_kwargs = to_device(kwargs, device) + kwargs = to_device(self._post_encode(model, old_kwargs), device) + for k, v in old_kwargs.items(): + if k in { + 'input_ids', 'attention_mask', 'labels', 'position_ids', 'output_hidden_states', 'logits_to_keep', + 'max_length_q', 'max_length_k', 'cu_seq_lens_q', 'cu_seq_lens_k' + } and k not in kwargs: + kwargs[k] = v + if 'inputs_embeds' in kwargs: + kwargs.pop('input_ids', None) + + from peft import PeftModel + if isinstance(model, PeftModel): + base_model = model.model + else: + base_model = model + parameters = inspect.signature(base_model.forward).parameters + if 'position_ids' not in parameters: + kwargs.pop('position_ids', None) + return args, kwargs diff --git a/src/twinkle/template/qwen3_vl.py 
b/src/twinkle/template/qwen3_vl.py index 325d028a..ff9ea711 100644 --- a/src/twinkle/template/qwen3_vl.py +++ b/src/twinkle/template/qwen3_vl.py @@ -1,10 +1,10 @@ import torch from PIL import Image from typing import Any, Dict, List, Optional, Union - from twinkle import remote_class from twinkle.template import Template from twinkle.template.base import ImageInput, VideoInput +from twinkle.template.utils import get_inputs_embeds_hf @remote_class() @@ -76,18 +76,19 @@ def preprocess_video(self, video: VideoInput) -> Union[List[Image.Image], torch. except ImportError: return super().preprocess_video(video) - # _build_messages: Uses base class implementation. - # Qwen's HF processor accepts the standard format: - # [{'role': 'user', 'content': [{'type': 'image'}, {'type': 'text', 'text': '...'}]}] - - def post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: - """Qwen3-VL handles embedding merge internally.""" - return inputs - - def _get_vision_token_id(self) -> Optional[int]: - if self.config is not None: - return getattr(self.config, 'image_token_id', None) - return None + def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: + input_ids = inputs['input_ids'] + from peft import PeftModel + if isinstance(model, PeftModel): + base_model = model.model + else: + base_model = model + if hasattr(base_model.model, 'embed_tokens'): + inputs_embeds = base_model.model.embed_tokens(input_ids) + else: + inputs_embeds = base_model.model.language_model.embed_tokens(input_ids) + inputs_embeds = get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config) + return {'inputs_embeds': inputs_embeds} def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional[torch.Tensor]: """Get 3D RoPE position_ids for Qwen VL.""" diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index 2bea8f22..126ae970 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -4,6 +4,7 @@ from typing 
import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple from twinkle.data_format import Message, Trajectory +from twinkle.utils import to_device if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -168,50 +169,65 @@ def _load_image(img: Any) -> Optional[Any]: return img -def _transfer_single_message(content: str, image_placeholder, video_placeholder, images, videos): - image_idx = 0 - video_idx = 0 - remaining = content - # Handle None images/videos - images = images or [] - videos = videos or [] - has_image = image_placeholder in content - has_video = video_placeholder in content - new_content = [] - while remaining: - img_pos = remaining.find(image_placeholder) if has_image else -1 - vid_pos = remaining.find(video_placeholder) if has_video else -1 - - # Find next placeholder - if img_pos == -1 and vid_pos == -1: - if remaining.strip(): - new_content.append({'type': 'text', 'text': remaining}) - break +def _transfer_single_message( + content: str, + image_placeholder: str, + video_placeholder: str, + audio_placeholder: str, + images: list | None = None, + videos: list | None = None, + audios: list | None = None, +) -> list[dict]: + if not content: + return [] + + media_configs = [ + (image_placeholder, 'image', images or []), + (video_placeholder, 'video', videos or []), + (audio_placeholder, 'audio', audios or []), + ] + + placeholders = [] + for placeholder, media_type, media_list in media_configs: + if not placeholder: + continue + start = 0 + media_idx = 0 + while (pos := content.find(placeholder, start)) != -1: + url = media_list[media_idx] if media_idx < len(media_list) else None + placeholders.append((pos, len(placeholder), media_type, url)) + media_idx += 1 + start = pos + len(placeholder) - # Determine which comes first - if vid_pos == -1 or (img_pos != -1 and img_pos < vid_pos): - # Image placeholder - if remaining[:img_pos].strip(): - new_content.append({'type': 'text', 'text': remaining[:img_pos]}) - if image_idx < len(images): - 
new_content.append({'type': 'image', 'url': images[image_idx]}) - image_idx += 1 - remaining = remaining[img_pos + len(image_placeholder):] - else: - # Video placeholder - if remaining[:vid_pos].strip(): - new_content.append({'type': 'text', 'text': remaining[:vid_pos]}) - if video_idx < len(videos): - new_content.append({'type': 'video', 'url': videos[video_idx]}) - video_idx += 1 - remaining = remaining[vid_pos + len(video_placeholder):] - return new_content + if not placeholders: + return [{'type': 'text', 'text': content}] if content.strip() else [] + + placeholders.sort(key=lambda x: x[0]) + + result = [] + cursor = 0 + + for pos, length, media_type, url in placeholders: + text_segment = content[cursor:pos] + if text_segment.strip(): + result.append({'type': 'text', 'text': text_segment}) + + if url is not None: + result.append({'type': media_type, 'url': url}) + + cursor = pos + length + trailing_text = content[cursor:] + if trailing_text.strip(): + result.append({'type': 'text', 'text': trailing_text}) -def transfer_to_standard_message(message: Message, image_placeholder, video_placeholder, is_mm): + return result + + +def transfer_to_standard_message(message: Message, image_placeholder, video_placeholder, audio_placeholder, is_mm): if is_mm: - new_content = _transfer_single_message(message['content'], image_placeholder, video_placeholder, - message.get('images'), message.get('videos')) + new_content = _transfer_single_message(message['content'], image_placeholder, video_placeholder, audio_placeholder, + message.get('images'), message.get('videos'), message.get('audios')) else: new_content = message['content'] @@ -220,3 +236,61 @@ def transfer_to_standard_message(message: Message, image_placeholder, video_plac content=new_content, tool_calls=message.get('tool_calls'), reasoning_content=message.get('reasoning_content')) + + +def get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config): + input_ids = inputs['input_ids'] + pixel_values = 
inputs.get('pixel_values') + pixel_values_videos = inputs.get('pixel_values_videos') + image_grid_thw = inputs.get('image_grid_thw') + video_grid_thw = inputs.get('video_grid_thw') + dtype = visual.dtype + if pixel_values is None and pixel_values_videos is None: + from PIL import Image + images = [Image.new('RGB', (32, 32), (0, 0, 0))] + media_inputs = processor.image_processor(images=images, return_tensors='pt') + media_inputs = to_device(media_inputs, input_ids.device) + pixel_values = media_inputs['pixel_values'].type(dtype) + image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) + if hasattr(image_embeds, 'pooler_output'): + image_embeds = image_embeds.pooler_output + inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0. + else: + import torch + if pixel_values is None: + pixel_values_mixed = pixel_values_videos + grid_thw = video_grid_thw + elif pixel_values_videos is None: + pixel_values_mixed = pixel_values + grid_thw = image_grid_thw + else: + pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0) + grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0) + pixel_values_mixed = pixel_values_mixed.type(dtype) + mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw) + if hasattr(mixed_embeds, 'pooler_output'): + mixed_embeds = mixed_embeds.pooler_output + if pixel_values is None: + image_embeds = None + video_embeds = mixed_embeds + elif pixel_values_videos is None: + image_embeds = mixed_embeds + video_embeds = None + else: + merge_length = processor.image_processor.merge_size ** 2 + image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum() + image_embeds = mixed_embeds[:image_tokens] + video_embeds = mixed_embeds[image_tokens:] + + if image_embeds is not None: + image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask = 
image_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if video_embeds is not None: + video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_mask = video_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + return inputs_embeds From 5203bedd8a29e8162199d2aba10e811eec0ea3f9 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 14:41:22 +0800 Subject: [PATCH 02/13] fix --- src/twinkle/hub/hub.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/twinkle/hub/hub.py b/src/twinkle/hub/hub.py index 899de321..701505e0 100644 --- a/src/twinkle/hub/hub.py +++ b/src/twinkle/hub/hub.py @@ -165,7 +165,7 @@ def load_dataset(cls, subset_name: str, split: str, streaming: bool = False, - revision: Optional[str] = None): + revision: Optional[str] = None, **kwargs): """Load a dataset from the repo Args: @@ -179,7 +179,7 @@ def load_dataset(cls, The Dataset instance """ hub = cls._get_hub_class(dataset_id) - return hub.load_dataset(cls.remove_source_type(dataset_id), subset_name, split, streaming, revision) + return hub.load_dataset(cls.remove_source_type(dataset_id), subset_name, split, streaming, revision, **kwargs) @classmethod def download_model(cls, From e7d84ceaef5bcacb3337b9e7d71321cc76ed2f34 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 15:37:40 +0800 Subject: [PATCH 03/13] wip --- cookbook/mm/fsdp2.py | 83 +++++++++++++++++++++++++++++++ src/twinkle/template/base.py | 3 +- src/twinkle/template/qwen3_vl.py | 5 +- src/twinkle/template/utils.py | 4 +- src/twinkle/utils/__init__.py | 1 + src/twinkle/utils/vision_tools.py | 54 ++++++++++++++++++++ 6 files changed, 147 insertions(+), 3 deletions(-) create mode 100644 cookbook/mm/fsdp2.py create mode 100644 src/twinkle/utils/vision_tools.py diff --git 
a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py new file mode 100644 index 00000000..bb3300b6 --- /dev/null +++ b/cookbook/mm/fsdp2.py @@ -0,0 +1,83 @@ +from peft import LoraConfig +from tqdm import tqdm + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.data_format import Trajectory, Message +from twinkle.dataloader import DataLoader +from twinkle.dataset import LazyDataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor, Preprocessor + +# Construct a device_mesh, fsdp=4, dp=2 +device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2) +# use torchrun mode +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +logger = get_logger() + + +class LatexOCRProcessor(Preprocessor): + + def __call__(self, row) -> Trajectory: + return Trajectory( + messages=[ + Message(role='user', content='Using LaTeX to perform OCR on the image.', images=[row['image']]), + Message(role='assistant', content=row['text']), + ] + ) + + +def train(): + # 2000 samples + dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000))) + # Set template to prepare encoding + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + # Preprocess the dataset to standard format + dataset.map(LatexOCRProcessor) + # Encode dataset + dataset.encode() + # Global batch size = 8, for GPUs, so 1 sample per GPU + dataloader = DataLoader(dataset=dataset, batch_size=8) + # Use a TransformersModel + model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model.model._no_split_modules = {'Qwen3_5DecoderLayer'} + + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + + # Add a lora to model, with name `default` + # Comment this to use full-parameter training + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) + # Add Optimizer for lora `default` + 
model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + # Add LRScheduler for lora `default` + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + logger.info(get_device_placement()) + # Print the training config + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + loss_metric = 99.0 + # lora: 18G * 4 + # full: 50G * 4 + for step, batch in enumerate(dataloader): + # Do forward and backward + model.forward_backward(inputs=batch) + # Step + model.clip_grad_and_step() + if step % 20 == 0: + # Print metric + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + if step > 0 and step % 40 == 0: + metrics = eval(model) + logger.info(f'Eval metric: {metrics}') + metrics['step'] = step + if loss_metric > float(metrics['loss']): + model.save(f'checkpoint-{step}') + loss_metric = float(metrics['loss']) + model.save(f'last-checkpoint') + + +if __name__ == '__main__': + train() diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index e73efd77..3751992c 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -11,6 +11,7 @@ from twinkle.hub import HubOperation from .utils import tokenize_with_assistant_labels, transfer_to_standard_message from twinkle.utils import to_device +from twinkle.utils import load_image if TYPE_CHECKING: import torch @@ -114,7 +115,7 @@ def _test_support_assistant_tokens_mask(self): self._template_support_assistant_tokens_mask = False def preprocess_image(self, image: ImageInput) -> 'Image.Image': - return image + return load_image(image) def preprocess_video(self, video: VideoInput) -> List['Image.Image']: return video diff --git a/src/twinkle/template/qwen3_vl.py b/src/twinkle/template/qwen3_vl.py index ff9ea711..b679fb80 100644 --- a/src/twinkle/template/qwen3_vl.py +++ b/src/twinkle/template/qwen3_vl.py @@ -1,10 +1,11 @@ 
import torch from PIL import Image from typing import Any, Dict, List, Optional, Union -from twinkle import remote_class +from twinkle import remote_class, requires from twinkle.template import Template from twinkle.template.base import ImageInput, VideoInput from twinkle.template.utils import get_inputs_embeds_hf +from twinkle.utils import load_image @remote_class() @@ -43,7 +44,9 @@ def merge_size(self) -> int: def preprocess_image(self, image: ImageInput) -> Image.Image: try: + requires('qwen_vl_utils') from qwen_vl_utils.vision_process import fetch_image + image = load_image(image) if isinstance(image, str): image_input = {'image': image} elif isinstance(image, Image.Image): diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index 126ae970..a5416542 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import inspect from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar from twinkle.data_format import Message, Trajectory from twinkle.utils import to_device @@ -9,6 +9,8 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer +_T = TypeVar('_T') + PLACEHOLDER = '<<>>' diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 1d7f9028..6bcd36d4 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -14,3 +14,4 @@ from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr +from .vision_tools import load_mm_file, load_image diff --git a/src/twinkle/utils/vision_tools.py b/src/twinkle/utils/vision_tools.py new file mode 100644 index 00000000..607d34a8 --- /dev/null +++ b/src/twinkle/utils/vision_tools.py @@ 
-0,0 +1,54 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import base64 +import os +import re +from io import BytesIO +from typing import Union, TypeVar, TYPE_CHECKING + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +if TYPE_CHECKING: + from PIL import Image + +_T = TypeVar('_T') + + +def load_mm_file(path: Union[str, bytes, _T]) -> Union[BytesIO, _T]: + res = path + if isinstance(path, str): + path = path.strip() + if path.startswith('http'): + retries = Retry(total=3, backoff_factor=1, allowed_methods=['GET']) + with requests.Session() as session: + session.mount('http://', HTTPAdapter(max_retries=retries)) + session.mount('https://', HTTPAdapter(max_retries=retries)) + response = session.get(path, timeout=10) + response.raise_for_status() + content = response.content + res = BytesIO(content) + else: + data = path + if os.path.exists(path): + with open(path, 'rb') as f: + res = BytesIO(f.read()) + else: + if data.startswith('data:'): + match_ = re.match(r'data:(.+?);base64,(.+)', data) + assert match_ is not None + data = match_.group(2) + data = base64.b64decode(data) + res = BytesIO(data) + elif isinstance(path, bytes): + res = BytesIO(path) + return res + + +def load_image(image: Union[str, bytes, 'Image.Image']) -> 'Image.Image': + image = load_mm_file(image) + if isinstance(image, BytesIO): + image = Image.open(image) + if image.mode != 'RGB': + image = image.convert('RGB') + return image \ No newline at end of file From b7eb2e144d810685eb259cf12623ace89758378f Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 15:38:39 +0800 Subject: [PATCH 04/13] wip --- src/twinkle/template/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 3751992c..730472cd 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -360,10 +360,10 @@ def _get_position_ids(self, inputs: 
Dict[str, Any]) -> Optional['torch.Tensor']: """Get position_ids. Override for models with special position encoding.""" return None - def _post_encode(self, model: torch.nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: + def _post_encode(self, model: 'torch.nn.Module', inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs - def pre_forward_hook(self, model: torch.nn.Module, args, kwargs): + def pre_forward_hook(self, model: 'torch.nn.Module', args, kwargs): device = next(model.parameters()).device old_kwargs = to_device(kwargs, device) kwargs = to_device(self._post_encode(model, old_kwargs), device) From ea0595e5a7130ac8a9a1b7b9bf8c45e1e4a97ac8 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 15:39:37 +0800 Subject: [PATCH 05/13] wip --- cookbook/mm/fsdp2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py index bb3300b6..b4ca4905 100644 --- a/cookbook/mm/fsdp2.py +++ b/cookbook/mm/fsdp2.py @@ -10,9 +10,9 @@ from twinkle.preprocessor import SelfCognitionProcessor, Preprocessor # Construct a device_mesh, fsdp=4, dp=2 -device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2) +# device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2) # use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) +# twinkle.initialize(mode='local', global_device_mesh=device_mesh) logger = get_logger() From bd0ee8ba25d3f9b5db47ca5f00d0e99b73f12de0 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 17:34:00 +0800 Subject: [PATCH 06/13] wip --- cookbook/mm/fsdp2.py | 6 ++++-- src/twinkle/processor/base.py | 8 ++++++-- src/twinkle/template/base.py | 5 +++++ src/twinkle/template/qwen3_vl.py | 30 +++++++++++++----------------- src/twinkle/template/utils.py | 4 ++-- src/twinkle/utils/vision_tools.py | 1 + 6 files changed, 31 insertions(+), 23 deletions(-) diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py index b4ca4905..a538553b 100644 --- 
a/cookbook/mm/fsdp2.py +++ b/cookbook/mm/fsdp2.py @@ -32,7 +32,7 @@ def train(): # 2000 samples dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Qwen3VLTemplate', model_id='ms://Qwen/Qwen3.5-4B') # Preprocess the dataset to standard format dataset.map(LatexOCRProcessor) # Encode dataset @@ -40,7 +40,8 @@ def train(): # Global batch size = 8, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=8) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration + model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', model_cls=Qwen3_5ForConditionalGeneration) model.model._no_split_modules = {'Qwen3_5DecoderLayer'} lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') @@ -49,6 +50,7 @@ def train(): # Comment this to use full-parameter training model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) # Add Optimizer for lora `default` + model.set_template('Qwen3VLTemplate', model_id='ms://Qwen/Qwen3.5-4B') model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) # Add LRScheduler for lora `default` model.set_lr_scheduler( diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index ca021ff4..e9c86d95 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -260,7 +260,8 @@ def _create_4d_attention_mask(attention_mask): @staticmethod def _get_packed_seq_params(position_ids): - assert position_ids.shape[0] == 1 + if position_ids.shape[0] > 1: + position_ids = position_ids[:1] position_ids_f = position_ids.flatten() indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32) @@ -305,7 +306,7 @@ def to_transformers_dict(inputs: 
List[InputFeature], **kwargs) -> List[InputFeat results = [] for _input in inputs: output = {} - _keys = ['input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask'] + _keys = ['input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask', 'pixel_values', 'image_grid_thw'] for key in list(_input.keys()): if key in _keys: output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key] @@ -361,6 +362,9 @@ def _collate_macro_batch(self, inputs: List[InputFeature]) -> InputFeature: for field, values in vlm_fields.items(): if values: + if values[0].dim() == 1: + # image_thw may be squeezed + values = [value.unsqueeze(0) for value in values] result[field] = torch.cat(values, dim=0) return result diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 730472cd..1bf0b34a 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -115,6 +115,11 @@ def _test_support_assistant_tokens_mask(self): self._template_support_assistant_tokens_mask = False def preprocess_image(self, image: ImageInput) -> 'Image.Image': + if isinstance(image, dict): + if image.get('path'): + image = image['path'] + else: + image = image['bytes'] return load_image(image) def preprocess_video(self, video: VideoInput) -> List['Image.Image']: diff --git a/src/twinkle/template/qwen3_vl.py b/src/twinkle/template/qwen3_vl.py index b679fb80..3a8adfe4 100644 --- a/src/twinkle/template/qwen3_vl.py +++ b/src/twinkle/template/qwen3_vl.py @@ -43,24 +43,20 @@ def merge_size(self) -> int: return self._merge_size or 2 def preprocess_image(self, image: ImageInput) -> Image.Image: - try: - requires('qwen_vl_utils') - from qwen_vl_utils.vision_process import fetch_image - image = load_image(image) - if isinstance(image, str): - image_input = {'image': image} - elif isinstance(image, Image.Image): - image_input = {'image': image} - else: - # Fallback to base class for tensor 
inputs - return super().preprocess_image(image) - - # Use qwen_vl_utils with correct patch_size - return fetch_image(image_input, image_patch_size=self.patch_size) - - except ImportError: + requires('qwen_vl_utils') + from qwen_vl_utils.vision_process import fetch_image + image = super().preprocess_image(image) + if isinstance(image, str): + image_input = {'image': image} + elif isinstance(image, Image.Image): + image_input = {'image': image} + else: + # Fallback to base class for tensor inputs return super().preprocess_image(image) + # Use qwen_vl_utils with correct patch_size + return fetch_image(image_input, image_patch_size=self.patch_size) + def preprocess_video(self, video: VideoInput) -> Union[List[Image.Image], torch.Tensor]: try: from qwen_vl_utils.vision_process import fetch_video @@ -90,7 +86,7 @@ def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: inputs_embeds = base_model.model.embed_tokens(input_ids) else: inputs_embeds = base_model.model.language_model.embed_tokens(input_ids) - inputs_embeds = get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config) + inputs_embeds = get_inputs_embeds_hf(inputs_embeds, inputs, base_model.model.visual, self.processor, model.config) return {'inputs_embeds': inputs_embeds} def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional[torch.Tensor]: diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index a5416542..e2ab35d6 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -104,14 +104,14 @@ def tokenize_with_assistant_labels( assistant_count += 1 _dummy_messages.append(msg) - encoded = encode_func(trajectory, ) + encoded = encode_func(trajectory) full_ids = encoded.pop('input_ids') if isinstance(full_ids, torch.Tensor): full_ids = full_ids.tolist()[0] _dummy_trajectory = copy(trajectory) _dummy_trajectory['messages'] = _dummy_messages - template_ids = encode_func(_dummy_trajectory, ) + template_ids = 
encode_func(_dummy_trajectory) template_ids = template_ids['input_ids'] if isinstance(template_ids, torch.Tensor): template_ids = template_ids.tolist()[0] diff --git a/src/twinkle/utils/vision_tools.py b/src/twinkle/utils/vision_tools.py index 607d34a8..8469cbc5 100644 --- a/src/twinkle/utils/vision_tools.py +++ b/src/twinkle/utils/vision_tools.py @@ -48,6 +48,7 @@ def load_mm_file(path: Union[str, bytes, _T]) -> Union[BytesIO, _T]: def load_image(image: Union[str, bytes, 'Image.Image']) -> 'Image.Image': image = load_mm_file(image) if isinstance(image, BytesIO): + from PIL import Image image = Image.open(image) if image.mode != 'RGB': image = image.convert('RGB') From 35cf9c157e941c3a1c2304c1acd93c19061dfd2a Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 19:14:02 +0800 Subject: [PATCH 07/13] fix packing --- cookbook/mm/fsdp2.py | 9 +++++---- src/twinkle/processor/base.py | 4 ++++ src/twinkle/template/qwen3_vl.py | 27 ++++++++++----------------- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py index a538553b..6758b603 100644 --- a/cookbook/mm/fsdp2.py +++ b/cookbook/mm/fsdp2.py @@ -5,7 +5,7 @@ from twinkle import DeviceMesh, get_device_placement, get_logger from twinkle.data_format import Trajectory, Message from twinkle.dataloader import DataLoader -from twinkle.dataset import LazyDataset, DatasetMeta +from twinkle.dataset import LazyDataset, PackingDataset, DatasetMeta from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor, Preprocessor @@ -30,15 +30,16 @@ def __call__(self, row) -> Trajectory: def train(): # 2000 samples - dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000))) + dataset = PackingDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000))) # Set template to prepare encoding - dataset.set_template('Qwen3VLTemplate', model_id='ms://Qwen/Qwen3.5-4B') 
+ dataset.set_template('Qwen3VLTemplate', model_id='ms://Qwen/Qwen3.5-4B', max_length=1024) # Preprocess the dataset to standard format dataset.map(LatexOCRProcessor) # Encode dataset dataset.encode() + dataset.pack_dataset() # Global batch size = 8, for GPUs, so 1 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=8) + dataloader = DataLoader(dataset=dataset, batch_size=2) # Use a TransformersModel from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', model_cls=Qwen3_5ForConditionalGeneration) diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index e9c86d95..5a90d39b 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -94,6 +94,10 @@ def to_tensor(_input): value = torch.from_numpy(value) elif isinstance(value, list) and isinstance(value[0], (int, float, np.number)): value = torch.tensor(value) + elif key in self.VLM_CONCAT_FIELDS: + if not isinstance(value[0], torch.Tensor): + value = [torch.tensor(v) for v in value] + value = torch.cat(value, dim=0) if isinstance(value, torch.Tensor): value = value.to(Platform.get_local_device()) if value.dim() == 1: diff --git a/src/twinkle/template/qwen3_vl.py b/src/twinkle/template/qwen3_vl.py index 3a8adfe4..366b66a2 100644 --- a/src/twinkle/template/qwen3_vl.py +++ b/src/twinkle/template/qwen3_vl.py @@ -18,9 +18,7 @@ class Qwen3VLTemplate(Template): """ def __init__(self, *args, **kwargs): - # TODO untested code super().__init__(*args, **kwargs) - # Cache processor config for preprocessing self._patch_size: Optional[int] = None self._merge_size: Optional[int] = None self._init_vision_config() @@ -58,21 +56,16 @@ def preprocess_image(self, image: ImageInput) -> Image.Image: return fetch_image(image_input, image_patch_size=self.patch_size) def preprocess_video(self, video: VideoInput) -> Union[List[Image.Image], torch.Tensor]: - try: - from 
qwen_vl_utils.vision_process import fetch_video - - if isinstance(video, str): - # Use qwen_vl_utils for video loading - video_input = {'video': video} - result = fetch_video(video_input, image_patch_size=self.patch_size, return_video_sample_fps=False) - return result - elif isinstance(video, list): - # List of images - preprocess each frame - return [self.preprocess_image(frame) for frame in video] - else: - return super().preprocess_video(video) - - except ImportError: + requires('qwen_vl_utils') + from qwen_vl_utils.vision_process import fetch_video + + if isinstance(video, str): + video_input = {'video': video} + result = fetch_video(video_input, image_patch_size=self.patch_size, return_video_sample_fps=False) + return result + elif isinstance(video, list): + return [self.preprocess_image(frame) for frame in video] + else: return super().preprocess_video(video) def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: From a69c36c2ec055a0712149b6c0a5be8cf144ad846 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 19:50:57 +0800 Subject: [PATCH 08/13] fix --- cookbook/mm/fsdp2.py | 33 ++++++++++++------- .../transformers/multi_lora_transformers.py | 3 +- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py index 6758b603..22c858e0 100644 --- a/cookbook/mm/fsdp2.py +++ b/cookbook/mm/fsdp2.py @@ -5,14 +5,14 @@ from twinkle import DeviceMesh, get_device_placement, get_logger from twinkle.data_format import Trajectory, Message from twinkle.dataloader import DataLoader -from twinkle.dataset import LazyDataset, PackingDataset, DatasetMeta +from twinkle.dataset import LazyDataset, DatasetMeta from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor, Preprocessor +from twinkle.preprocessor import Preprocessor -# Construct a device_mesh, fsdp=4, dp=2 -# device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2) +# Construct a device_mesh, 
fsdp=2 +device_mesh = DeviceMesh.from_sizes(fsdp_size=2) # use torchrun mode -# twinkle.initialize(mode='local', global_device_mesh=device_mesh) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) logger = get_logger() @@ -28,18 +28,31 @@ def __call__(self, row) -> Trajectory: ) +def eval(model): + # 100 Samples + dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(100))) + dataset.set_template('Qwen3VLTemplate', model_id='ms://Qwen/Qwen3.5-4B') + dataset.map(LatexOCRProcessor) + dataset.encode() + dataloader = DataLoader(dataset=dataset, batch_size=8) + for step, batch in tqdm(enumerate(dataloader)): + model.forward_only(inputs=batch) + model.calculate_loss() + metrics = model.calculate_metric(is_training=False) + return metrics + + def train(): # 2000 samples - dataset = PackingDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000))) + dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000))) # Set template to prepare encoding dataset.set_template('Qwen3VLTemplate', model_id='ms://Qwen/Qwen3.5-4B', max_length=1024) # Preprocess the dataset to standard format dataset.map(LatexOCRProcessor) # Encode dataset dataset.encode() - dataset.pack_dataset() - # Global batch size = 8, for GPUs, so 1 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=2) + # Global batch size = 4, for GPUs, so 2 sample per GPU + dataloader = DataLoader(dataset=dataset, batch_size=4) # Use a TransformersModel from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B', model_cls=Qwen3_5ForConditionalGeneration) @@ -61,8 +74,6 @@ def train(): logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') loss_metric = 99.0 - # lora: 18G * 4 - # full: 50G * 4 for step, batch in enumerate(dataloader): # Do forward and backward 
model.forward_backward(inputs=batch) diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index c638d06e..fde05bcb 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -68,13 +68,14 @@ def register_global_mm_forward_hook(self): def forward_hook(model: torch.nn.Module, args, kwargs): active_adapter = model.active_adapters[0] + active_adapter = self.multi_adapter.find_lora(active_adapter).tenant_adapter_name optimizer_group = self.optimizer_group[active_adapter] template = optimizer_group.template assert template is not None return template.pre_forward_hook(model, args, kwargs) model = self.strategy.unwrap_model(self.model) - return model.register_forward_hook(forward_hook) + return model.register_forward_pre_hook(forward_hook, with_kwargs=True) def register_mm_forward_hook(self, optimizer_group: OptimizerGroup): pass From 274442eba55de30d913c8ef15d8bfddf28ea6b27 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 19:52:54 +0800 Subject: [PATCH 09/13] fix --- src/twinkle/hub/hub.py | 3 ++- src/twinkle/processor/base.py | 5 ++++- src/twinkle/template/base.py | 9 +++------ src/twinkle/template/qwen3_vl.py | 4 +++- src/twinkle/template/utils.py | 21 +++++++++++---------- src/twinkle/utils/__init__.py | 2 +- src/twinkle/utils/vision_tools.py | 7 +++---- 7 files changed, 27 insertions(+), 24 deletions(-) diff --git a/src/twinkle/hub/hub.py b/src/twinkle/hub/hub.py index 701505e0..6e1653e1 100644 --- a/src/twinkle/hub/hub.py +++ b/src/twinkle/hub/hub.py @@ -165,7 +165,8 @@ def load_dataset(cls, subset_name: str, split: str, streaming: bool = False, - revision: Optional[str] = None, **kwargs): + revision: Optional[str] = None, + **kwargs): """Load a dataset from the repo Args: diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 5a90d39b..e7bce2d3 100644 --- 
a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -310,7 +310,10 @@ def to_transformers_dict(inputs: List[InputFeature], **kwargs) -> List[InputFeat results = [] for _input in inputs: output = {} - _keys = ['input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask', 'pixel_values', 'image_grid_thw'] + _keys = [ + 'input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask', + 'pixel_values', 'image_grid_thw' + ] for key in list(_input.keys()): if key in _keys: output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key] diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 1bf0b34a..9e72d683 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -1,17 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import inspect - import numpy as np import os from collections.abc import Mapping -from copy import deepcopy, copy +from copy import copy, deepcopy from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union from twinkle.data_format import InputFeature, Message, Trajectory from twinkle.hub import HubOperation +from twinkle.utils import load_image, to_device from .utils import tokenize_with_assistant_labels, transfer_to_standard_message -from twinkle.utils import to_device -from twinkle.utils import load_image if TYPE_CHECKING: import torch @@ -235,8 +233,7 @@ def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: assert len(message['audios']) == content.count(self.audio_placeholder) new_messages.append( transfer_to_standard_message(message, self.image_placeholder, self.video_placeholder, - self.audio_placeholder, - self.is_mm)) + self.audio_placeholder, self.is_mm)) trajectory['messages'] = new_messages return [trajectory] diff --git a/src/twinkle/template/qwen3_vl.py b/src/twinkle/template/qwen3_vl.py index 366b66a2..2cfe2019 
100644 --- a/src/twinkle/template/qwen3_vl.py +++ b/src/twinkle/template/qwen3_vl.py @@ -1,6 +1,7 @@ import torch from PIL import Image from typing import Any, Dict, List, Optional, Union + from twinkle import remote_class, requires from twinkle.template import Template from twinkle.template.base import ImageInput, VideoInput @@ -79,7 +80,8 @@ def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: inputs_embeds = base_model.model.embed_tokens(input_ids) else: inputs_embeds = base_model.model.language_model.embed_tokens(input_ids) - inputs_embeds = get_inputs_embeds_hf(inputs_embeds, inputs, base_model.model.visual, self.processor, model.config) + inputs_embeds = get_inputs_embeds_hf(inputs_embeds, inputs, base_model.model.visual, self.processor, + model.config) return {'inputs_embeds': inputs_embeds} def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional[torch.Tensor]: diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py index e2ab35d6..ccdcfc91 100644 --- a/src/twinkle/template/utils.py +++ b/src/twinkle/template/utils.py @@ -172,13 +172,13 @@ def _load_image(img: Any) -> Optional[Any]: def _transfer_single_message( - content: str, - image_placeholder: str, - video_placeholder: str, - audio_placeholder: str, - images: list | None = None, - videos: list | None = None, - audios: list | None = None, + content: str, + image_placeholder: str, + video_placeholder: str, + audio_placeholder: str, + images: list | None = None, + videos: list | None = None, + audios: list | None = None, ) -> list[dict]: if not content: return [] @@ -228,8 +228,9 @@ def _transfer_single_message( def transfer_to_standard_message(message: Message, image_placeholder, video_placeholder, audio_placeholder, is_mm): if is_mm: - new_content = _transfer_single_message(message['content'], image_placeholder, video_placeholder, audio_placeholder, - message.get('images'), message.get('videos'), message.get('audios')) + new_content = 
_transfer_single_message(message['content'], image_placeholder, video_placeholder, + audio_placeholder, message.get('images'), message.get('videos'), + message.get('audios')) else: new_content = message['content'] @@ -279,7 +280,7 @@ def get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config): image_embeds = mixed_embeds video_embeds = None else: - merge_length = processor.image_processor.merge_size ** 2 + merge_length = processor.image_processor.merge_size**2 image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum() image_embeds = mixed_embeds[:image_tokens] video_embeds = mixed_embeds[image_tokens:] diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py index 6bcd36d4..1b018773 100644 --- a/src/twinkle/utils/__init__.py +++ b/src/twinkle/utils/__init__.py @@ -14,4 +14,4 @@ from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert from .unsafe import check_unsafe, trust_remote_code from .utils import copy_files_by_pattern, deep_getattr -from .vision_tools import load_mm_file, load_image +from .vision_tools import load_image, load_mm_file diff --git a/src/twinkle/utils/vision_tools.py b/src/twinkle/utils/vision_tools.py index 8469cbc5..5638fca0 100644 --- a/src/twinkle/utils/vision_tools.py +++ b/src/twinkle/utils/vision_tools.py @@ -2,11 +2,10 @@ import base64 import os import re -from io import BytesIO -from typing import Union, TypeVar, TYPE_CHECKING - import requests +from io import BytesIO from requests.adapters import HTTPAdapter +from typing import TYPE_CHECKING, TypeVar, Union from urllib3.util.retry import Retry if TYPE_CHECKING: @@ -52,4 +51,4 @@ def load_image(image: Union[str, bytes, 'Image.Image']) -> 'Image.Image': image = Image.open(image) if image.mode != 'RGB': image = image.convert('RGB') - return image \ No newline at end of file + return image From 94a8721cda5de6ff3643cdd667c5cb82edd2b6ed Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 19:54:21 +0800 
Subject: [PATCH 10/13] fix --- src/twinkle/model/transformers/multi_lora_transformers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index fde05bcb..f4bccc53 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -1,7 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from peft import LoraConfig, PeftConfig, PeftModel, load_peft_weights -from sympy.printing.pytorch import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel @@ -66,7 +65,7 @@ def _check_adapter_valid(self, adapter_name: str): def register_global_mm_forward_hook(self): - def forward_hook(model: torch.nn.Module, args, kwargs): + def forward_hook(model, args, kwargs): active_adapter = model.active_adapters[0] active_adapter = self.multi_adapter.find_lora(active_adapter).tenant_adapter_name optimizer_group = self.optimizer_group[active_adapter] From d96ec37cb75cde72db8642bf15a2cf7aa49f690b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 20:05:05 +0800 Subject: [PATCH 11/13] fix --- cookbook/mm/fsdp2.py | 6 ++-- src/twinkle/template/__init__.py | 2 +- src/twinkle/template/base.py | 6 ++-- .../template/{qwen3_vl.py => qwen3_5_vl.py} | 33 +------------------ tests/dataloader/test_multimodal.py | 4 +-- tests/dataset/test_multimodal.py | 10 +++--- tests/sampler/align_swift.py | 6 ++-- 7 files changed, 17 insertions(+), 50 deletions(-) rename src/twinkle/template/{qwen3_vl.py => qwen3_5_vl.py} (73%) diff --git a/cookbook/mm/fsdp2.py b/cookbook/mm/fsdp2.py index 22c858e0..a93cd705 100644 --- a/cookbook/mm/fsdp2.py +++ b/cookbook/mm/fsdp2.py @@ -31,7 +31,7 @@ def __call__(self, row) -> Trajectory: def eval(model): # 100 Samples 
dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(100))) - dataset.set_template('Qwen3VLTemplate', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') dataset.map(LatexOCRProcessor) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -46,7 +46,7 @@ def train(): # 2000 samples dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(2000))) # Set template to prepare encoding - dataset.set_template('Qwen3VLTemplate', model_id='ms://Qwen/Qwen3.5-4B', max_length=1024) + dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B', max_length=1024) # Preprocess the dataset to standard format dataset.map(LatexOCRProcessor) # Encode dataset @@ -64,7 +64,7 @@ def train(): # Comment this to use full-parameter training model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) # Add Optimizer for lora `default` - model.set_template('Qwen3VLTemplate', model_id='ms://Qwen/Qwen3.5-4B') + model.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) # Add LRScheduler for lora `default` model.set_lr_scheduler( diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 346ca37e..324ce7ac 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .base import Template -from .qwen3_vl import Qwen3VLTemplate +from .qwen3_5_vl import Qwen3_5Template diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 9e72d683..dfa04a58 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -358,14 +358,12 @@ def _get_vision_token_id(self) -> Optional[int]: else: return self.processor.encode(self.image_placeholder) - def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional['torch.Tensor']: - """Get position_ids. Override for models with special position encoding.""" - return None - def _post_encode(self, model: 'torch.nn.Module', inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs def pre_forward_hook(self, model: 'torch.nn.Module', args, kwargs): + if not self.is_mm: + return args, kwargs device = next(model.parameters()).device old_kwargs = to_device(kwargs, device) kwargs = to_device(self._post_encode(model, old_kwargs), device) diff --git a/src/twinkle/template/qwen3_vl.py b/src/twinkle/template/qwen3_5_vl.py similarity index 73% rename from src/twinkle/template/qwen3_vl.py rename to src/twinkle/template/qwen3_5_vl.py index 2cfe2019..b95877a6 100644 --- a/src/twinkle/template/qwen3_vl.py +++ b/src/twinkle/template/qwen3_5_vl.py @@ -6,11 +6,10 @@ from twinkle.template import Template from twinkle.template.base import ImageInput, VideoInput from twinkle.template.utils import get_inputs_embeds_hf -from twinkle.utils import load_image @remote_class() -class Qwen3VLTemplate(Template): +class Qwen3_5Template(Template): """ Processor for Qwen VL series. 
@@ -83,33 +82,3 @@ def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: inputs_embeds = get_inputs_embeds_hf(inputs_embeds, inputs, base_model.model.visual, self.processor, model.config) return {'inputs_embeds': inputs_embeds} - - def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional[torch.Tensor]: - """Get 3D RoPE position_ids for Qwen VL.""" - if self.model is None: - return None - - input_ids = inputs.get('input_ids') - if input_ids is None: - return None - - # Find get_rope_index - base_model = self.model - if hasattr(base_model, 'base_model'): - base_model = base_model.base_model - if hasattr(base_model, 'model'): - base_model = base_model.model - - get_rope_index = getattr(base_model, 'get_rope_index', None) - if get_rope_index is None and hasattr(base_model, 'model'): - get_rope_index = getattr(base_model.model, 'get_rope_index', None) - - if get_rope_index is None: - return None - - try: - position_ids, _ = get_rope_index(input_ids, inputs.get('image_grid_thw'), inputs.get('video_grid_thw'), - inputs.get('attention_mask')) - return position_ids - except Exception: - return None diff --git a/tests/dataloader/test_multimodal.py b/tests/dataloader/test_multimodal.py index 9b4905bb..0031b150 100644 --- a/tests/dataloader/test_multimodal.py +++ b/tests/dataloader/test_multimodal.py @@ -28,9 +28,9 @@ def test_dataloader_multimodal_with_lazy_dataset(self): dataset.map(create_multimodal_messages) try: - dataset.set_template('Qwen3VLTemplate', model_id='Qwen/Qwen2-VL-7B-Instruct') + dataset.set_template('Qwen3_5Template', model_id='Qwen/Qwen2-VL-7B-Instruct') except Exception as e: - pytest.skip(f'Failed to load Qwen3VLTemplate (may need network): {e}') + pytest.skip(f'Failed to load Qwen3_5Template (may need network): {e}') dataset.encode() diff --git a/tests/dataset/test_multimodal.py b/tests/dataset/test_multimodal.py index 31e92239..5fca8f4e 100644 --- a/tests/dataset/test_multimodal.py +++ b/tests/dataset/test_multimodal.py @@ 
-37,15 +37,15 @@ def test_multimodal_dataset_basic(self): @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download') def test_multimodal_dataset_with_qwen3vl_template(self): - # Use Qwen3VLTemplate + # Use Qwen3_5Template csv_path = str(TEST_DATA_DIR / 'test.csv') dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path)) dataset.map(create_multimodal_messages) try: - dataset.set_template('Qwen3VLTemplate', model_id='Qwen/Qwen3-VL-2B-Instruct') + dataset.set_template('Qwen3_5Template', model_id='Qwen/Qwen3-VL-2B-Instruct') except Exception as e: - pytest.skip(f'Failed to load Qwen3VLTemplate (may need network): {e}') + pytest.skip(f'Failed to load Qwen3_5Template (may need network): {e}') assert dataset.template is not None assert hasattr(dataset.template, 'is_mm') @@ -58,9 +58,9 @@ def test_multimodal_dataset_encode_with_lazy(self): dataset.map(create_multimodal_messages) try: - dataset.set_template('Qwen3VLTemplate', model_id='Qwen/Qwen3-VL-2B-Instruct') + dataset.set_template('Qwen3_5Template', model_id='Qwen/Qwen3-VL-2B-Instruct') except Exception as e: - pytest.skip(f'Failed to load Qwen3VLTemplate (may need network): {e}') + pytest.skip(f'Failed to load Qwen3_5Template (may need network): {e}') try: dataset.encode() diff --git a/tests/sampler/align_swift.py b/tests/sampler/align_swift.py index b9bcec77..568c0563 100644 --- a/tests/sampler/align_swift.py +++ b/tests/sampler/align_swift.py @@ -28,7 +28,7 @@ from twinkle.sampler.torch_sampler import TorchSampler from twinkle.sampler.vllm_sampler import vLLMSampler from twinkle.template import Template -from twinkle.template.qwen3_vl import Qwen3VLTemplate +from twinkle.template.qwen3_vl import Qwen3_5Template # Test models LLM_MODEL_ID = 'Qwen/Qwen2.5-7B-Instruct' @@ -269,7 +269,7 @@ def test_mllm_torch_sampler(): seed_everything(42) from transformers import Qwen3VLForConditionalGeneration sampler = TorchSampler(MLLM_MODEL_ID, model_cls=Qwen3VLForConditionalGeneration) - 
sampler.set_template(Qwen3VLTemplate, model_id=MLLM_MODEL_ID) + sampler.set_template(Qwen3_5Template, model_id=MLLM_MODEL_ID) trajectory = Trajectory(messages=MLLM_MESSAGES, images=MLLM_IMAGES) sampling_params = SamplingParams(max_tokens=128, temperature=0) @@ -297,7 +297,7 @@ def test_mllm_vllm_sampler(): seed_everything(42) sampler = vLLMSampler(MLLM_MODEL_ID, gpu_memory_utilization=VLLM_GPU_MEM, max_model_len=VLLM_MAX_MODEL_LEN) - sampler.set_template(Qwen3VLTemplate, model_id=MLLM_MODEL_ID) + sampler.set_template(Qwen3_5Template, model_id=MLLM_MODEL_ID) trajectory = Trajectory(messages=MLLM_MESSAGES, images=MLLM_IMAGES) sampling_params = SamplingParams(max_tokens=128, temperature=0) From 53d32a48eaeebc2a091ee977de2974b1267f166d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 20:20:34 +0800 Subject: [PATCH 12/13] fix --- cookbook/mm/fsdp2.sh | 1 + cookbook/transformers/fsdp2.py | 7 +++---- cookbook/transformers/sp_fsdp_dense.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) create mode 100644 cookbook/mm/fsdp2.sh diff --git a/cookbook/mm/fsdp2.sh b/cookbook/mm/fsdp2.sh new file mode 100644 index 00000000..46e9f27f --- /dev/null +++ b/cookbook/mm/fsdp2.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 fsdp2.py diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 5624495d..ca37d724 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -20,7 +20,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -35,7 +35,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', 
data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -43,8 +43,7 @@ def train(): # Global batch size = 8, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=8) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') - model.model._no_split_modules = {'Qwen3_5DecoderLayer'} + model = TransformersModel(model_id='ms://Qwen/Qwen3-4B') lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 868b61c0..da6e2d28 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -10,7 +10,7 @@ from twinkle.preprocessor import SelfCognitionProcessor logger = get_logger() -MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +MODEL_ID = 'ms://Qwen/Qwen3-4B' DATASETS = 'ms://swift/self-cognition' device_group = [DeviceGroup( From 93576dc2f303a5c1a31da2397cbf548364567c0d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Wed, 4 Mar 2026 20:23:07 +0800 Subject: [PATCH 13/13] fix --- tests/sampler/align_swift.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/sampler/align_swift.py b/tests/sampler/align_swift.py index 568c0563..dc33ff36 100644 --- a/tests/sampler/align_swift.py +++ b/tests/sampler/align_swift.py @@ -17,7 +17,6 @@ import gc import os -import sys import torch from swift.infer_engine import RequestConfig, TransformersEngine, VllmEngine from swift.utils import seed_everything @@ -27,8 +26,7 @@ from twinkle.data_format import SamplingParams, Trajectory from twinkle.sampler.torch_sampler import TorchSampler from twinkle.sampler.vllm_sampler import vLLMSampler -from twinkle.template 
import Template -from twinkle.template.qwen3_vl import Qwen3_5Template +from twinkle.template import Qwen3_5Template, Template # Test models LLM_MODEL_ID = 'Qwen/Qwen2.5-7B-Instruct' @@ -182,7 +180,6 @@ def test_llm_vllm_sampler_ray(): from peft import LoraConfig from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger - from twinkle.checkpoint_engine import CheckpointEngineManager from twinkle.model import TransformersModel from twinkle.processor import InputProcessor