From 905efd5d1d929ccbf5d9d8e796b827ffe18dbae6 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 18 Dec 2025 22:43:28 +0800 Subject: [PATCH 001/692] wip --- .../kernel/csrc => cookbook}/placeholder | 0 src/ms-components/kernel/__init__.py | 8 ++ src/ms-components/model/model_meta.py | 11 ++ src/ms-components/platform/__init__.py | 0 src/ms-components/utils/__init__.py | 1 + src/ms-components/utils/framework.py | 100 ++++++++++++++++++ 6 files changed, 120 insertions(+) rename {src/ms-components/kernel/csrc => cookbook}/placeholder (100%) create mode 100644 src/ms-components/model/model_meta.py create mode 100644 src/ms-components/platform/__init__.py create mode 100644 src/ms-components/utils/__init__.py create mode 100644 src/ms-components/utils/framework.py diff --git a/src/ms-components/kernel/csrc/placeholder b/cookbook/placeholder similarity index 100% rename from src/ms-components/kernel/csrc/placeholder rename to cookbook/placeholder diff --git a/src/ms-components/kernel/__init__.py b/src/ms-components/kernel/__init__.py index e69de29b..c8bf449e 100644 --- a/src/ms-components/kernel/__init__.py +++ b/src/ms-components/kernel/__init__.py @@ -0,0 +1,8 @@ +from typing import Union, Callable, Any, List + + +def apply_kernel(module: Any, + kernel: Union[str, 'torch.nn.Module', Callable[[*Any], Any]], + target_modules: Union[str, List[str]]) -> None: + + if module.__class__ \ No newline at end of file diff --git a/src/ms-components/model/model_meta.py b/src/ms-components/model/model_meta.py new file mode 100644 index 00000000..6a3259ef --- /dev/null +++ b/src/ms-components/model/model_meta.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Literal + + +@dataclass +class ModelMeta: + + library: Literal['transformers', 'megatron'] + + framework: Literal['torch'] = 'torch' + diff --git a/src/ms-components/platform/__init__.py b/src/ms-components/platform/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/src/ms-components/utils/__init__.py b/src/ms-components/utils/__init__.py new file mode 100644 index 00000000..a463261d --- /dev/null +++ b/src/ms-components/utils/__init__.py @@ -0,0 +1 @@ +from .framework import get_library \ No newline at end of file diff --git a/src/ms-components/utils/framework.py b/src/ms-components/utils/framework.py new file mode 100644 index 00000000..5fd17a71 --- /dev/null +++ b/src/ms-components/utils/framework.py @@ -0,0 +1,100 @@ +from abc import ABC, abstractmethod +from typing import Literal +from functools import lru_cache + + +class Framework(ABC): + + @staticmethod + @abstractmethod + def get_library(module) -> str: + """Get the library name of the input module""" + ... + + @staticmethod + @abstractmethod + def get_current_device(): + """Set the current device""" + ... + + @staticmethod + @abstractmethod + def get_device(): + """Get the device type""" + ... + + @staticmethod + @abstractmethod + def set_device(idx: int): + """Set the current device""" + ... 
+ + + +class Torch(Framework): + + @staticmethod + def get_library(module) -> Literal['transformers', 'megatron', 'other']: + module_path = type(module).__module__ + if "transformers" in module_path: + return "transformers" + elif "megatron" in module_path: + return "megatron" + else: + return "other" + + @lru_cache + @staticmethod + def is_torch_npu_available(check_device=False) -> bool: + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if not _torch_available or importlib.util.find_spec("torch_npu") is None: + return False + + import torch + import torch_npu # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False + return hasattr(torch, "npu") and torch.npu.is_available() + + @staticmethod + @lru_cache + def get_current_device(): + import torch + if torch.cuda.is_available(): + return torch.cuda.current_device() + elif + + + @staticmethod + def get_device(): + pass + + @staticmethod + def set_device(idx: int): + pass + + +def get_framework(module) -> Literal['torch', 'other']: + if "torch" in type(module).__module__ or hasattr(module, "parameters"): + return "torch" + return 'other' + + +def get_library(module) -> Literal['transformers', 'megatron', 'other']: + """Get The library of one module + + Args: + module: A torch.nn.Module instance + + Returns: + A string representing the library, supports `transformers` or `megatron` or `other` + """ + if get_framework(module) == 'torch': + return Torch.get_library(module) + return 'other' From 6bbaebea35b47d11b27296b24169fe200cd43cad Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 19 Dec 2025 11:47:51 +0800 Subject: [PATCH 002/692] wip --- src/ms-components/kernel/__init__.py | 8 -- src/ms-components/model/model_meta.py | 11 -- src/ms-components/platform/__init__.py | 0 src/ms-components/utils/__init__.py | 1 - 
src/ms-components/utils/framework.py | 100 -------------- src/twinkle/__init__.py | 22 +++ src/twinkle/kernel/__init__.py | 37 +++++ src/twinkle/utils/__init__.py | 3 + src/twinkle/utils/framework.py | 180 +++++++++++++++++++++++++ src/twinkle/utils/import_utils.py | 91 +++++++++++++ src/twinkle/version.py | 5 + 11 files changed, 338 insertions(+), 120 deletions(-) delete mode 100644 src/ms-components/kernel/__init__.py delete mode 100644 src/ms-components/model/model_meta.py delete mode 100644 src/ms-components/platform/__init__.py delete mode 100644 src/ms-components/utils/__init__.py delete mode 100644 src/ms-components/utils/framework.py create mode 100644 src/twinkle/utils/__init__.py create mode 100644 src/twinkle/utils/framework.py create mode 100644 src/twinkle/utils/import_utils.py create mode 100644 src/twinkle/version.py diff --git a/src/ms-components/kernel/__init__.py b/src/ms-components/kernel/__init__.py deleted file mode 100644 index c8bf449e..00000000 --- a/src/ms-components/kernel/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Union, Callable, Any, List - - -def apply_kernel(module: Any, - kernel: Union[str, 'torch.nn.Module', Callable[[*Any], Any]], - target_modules: Union[str, List[str]]) -> None: - - if module.__class__ \ No newline at end of file diff --git a/src/ms-components/model/model_meta.py b/src/ms-components/model/model_meta.py deleted file mode 100644 index 6a3259ef..00000000 --- a/src/ms-components/model/model_meta.py +++ /dev/null @@ -1,11 +0,0 @@ -from dataclasses import dataclass -from typing import Literal - - -@dataclass -class ModelMeta: - - library: Literal['transformers', 'megatron'] - - framework: Literal['torch'] = 'torch' - diff --git a/src/ms-components/platform/__init__.py b/src/ms-components/platform/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/ms-components/utils/__init__.py b/src/ms-components/utils/__init__.py deleted file mode 100644 index a463261d..00000000 --- 
a/src/ms-components/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .framework import get_library \ No newline at end of file diff --git a/src/ms-components/utils/framework.py b/src/ms-components/utils/framework.py deleted file mode 100644 index 5fd17a71..00000000 --- a/src/ms-components/utils/framework.py +++ /dev/null @@ -1,100 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Literal -from functools import lru_cache - - -class Framework(ABC): - - @staticmethod - @abstractmethod - def get_library(module) -> str: - """Get the library name of the input module""" - ... - - @staticmethod - @abstractmethod - def get_current_device(): - """Set the current device""" - ... - - @staticmethod - @abstractmethod - def get_device(): - """Get the device type""" - ... - - @staticmethod - @abstractmethod - def set_device(idx: int): - """Set the current device""" - ... - - - -class Torch(Framework): - - @staticmethod - def get_library(module) -> Literal['transformers', 'megatron', 'other']: - module_path = type(module).__module__ - if "transformers" in module_path: - return "transformers" - elif "megatron" in module_path: - return "megatron" - else: - return "other" - - @lru_cache - @staticmethod - def is_torch_npu_available(check_device=False) -> bool: - "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" - if not _torch_available or importlib.util.find_spec("torch_npu") is None: - return False - - import torch - import torch_npu # noqa: F401 - - if check_device: - try: - # Will raise a RuntimeError if no NPU is found - _ = torch.npu.device_count() - return torch.npu.is_available() - except RuntimeError: - return False - return hasattr(torch, "npu") and torch.npu.is_available() - - @staticmethod - @lru_cache - def get_current_device(): - import torch - if torch.cuda.is_available(): - return torch.cuda.current_device() - elif - - - @staticmethod - def get_device(): - pass - - @staticmethod - def set_device(idx: int): - pass - - 
-def get_framework(module) -> Literal['torch', 'other']: - if "torch" in type(module).__module__ or hasattr(module, "parameters"): - return "torch" - return 'other' - - -def get_library(module) -> Literal['transformers', 'megatron', 'other']: - """Get The library of one module - - Args: - module: A torch.nn.Module instance - - Returns: - A string representing the library, supports `transformers` or `megatron` or `other` - """ - if get_framework(module) == 'torch': - return Torch.get_library(module) - return 'other' diff --git a/src/twinkle/__init__.py b/src/twinkle/__init__.py index e69de29b..4c9dee08 100644 --- a/src/twinkle/__init__.py +++ b/src/twinkle/__init__.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING +from .utils.import_utils import _LazyModule # noqa + +if TYPE_CHECKING: + from .version import __version__, __release_datetime__ + from .utils import framework, torch, requires, exists + +else: + _import_structure = { + 'version': ['__release_datetime__', '__version__'], + 'utils': ['framework', 'torch', 'requires'], + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, # noqa + extra_objects={}, + ) diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py index e69de29b..81448919 100644 --- a/src/twinkle/kernel/__init__.py +++ b/src/twinkle/kernel/__init__.py @@ -0,0 +1,37 @@ +from typing import Union, Callable, Any, List, Optional, Literal +from ..utils import torch as torch_util, framework as framework_util +from ..utils import exists + + +def apply_kernel(module: Any, + mode: Literal['train', 'inference', 'compile', None] = 'train', + kernel: Optional[Union[str, Callable[[*Any], Any]]]=None, + target_modules: Union[str, List[str]]=None, + device: Optional[Union[str, Any]] = None, + ) -> Any: + if framework_util.get_framework(module) == 'torch': + if torch_util.get_library(module) == 'transformers': + if exists('kernels'): + from kernels 
import kernelize, Mode + kernel_mode = Mode.TRAINING + if mode == 'inference': + kernel_mode = Mode.INFERENCE + elif mode == 'compile': + kernel_mode = Mode.TORCH_COMPILE + from kernels import kernelize + return kernelize(module, mode=kernel_mode, device=device) + + assert target_modules is not None and kernel is not None + + + else: + raise NotImplementedError(f'Unsupported applying kernels for: {module.__class__}') + + +def apply_kernel_torch(module: Any, + mode: Literal['train', 'inference', 'compile', None] = 'train', + kernel: Optional[Union[str, Callable[[*Any], Any]]]=None, + target_modules: Union[str, List[str]]=None, + device: Optional[Union[str, Any]] = None,): + + diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py new file mode 100644 index 00000000..90983801 --- /dev/null +++ b/src/twinkle/utils/__init__.py @@ -0,0 +1,3 @@ +from .framework import Torch as torch +from .framework import Framework as framework +from .import_utils import requires, exists diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py new file mode 100644 index 00000000..237161b4 --- /dev/null +++ b/src/twinkle/utils/framework.py @@ -0,0 +1,180 @@ +import importlib +import os +from abc import ABC, abstractmethod +from typing import Literal, Union +from functools import lru_cache + + +class Framework(ABC): + + @staticmethod + @abstractmethod + def get_current_device() -> int: + """Set the current device""" + ... + + @staticmethod + @abstractmethod + def get_device(local_rank) -> str: + """Get the device of the specified rank""" + ... + + @staticmethod + @abstractmethod + def set_device(local_rank: Union[str, int]) -> None: + """Set the current device""" + ... 
+ + @staticmethod + def get_rank() -> int: + """Get the global rank""" + return int(os.getenv('RANK', -1)) + + @staticmethod + def get_local_rank() -> int: + """Get the local rank""" + return int(os.getenv('LOCAL_RANK', -1)) + + @staticmethod + def get_world_size() -> int: + """Get the world size""" + return int(os.getenv('WORLD_SIZE') or os.getenv('_PATCH_WORLD_SIZE') or 1) + + @staticmethod + def get_local_world_size() -> int: + """Get the local world size""" + return int(os.getenv('LOCAL_WORLD_SIZE', None) or os.getenv('LOCAL_SIZE', 1)) + + @staticmethod + def get_nnodes() -> int: + """Get the node count""" + return int(os.getenv('NNODES', 1)) + + @staticmethod + def get_node_rank() -> int: + """Get the current node rank""" + return int(os.getenv('NODE_RANK', 0)) + + @staticmethod + def is_local_master() -> bool: + """Get if current is the local master""" + local_rank = Framework.get_local_rank() + return local_rank in {-1, 0} + + @staticmethod + def is_master() -> bool: + """Get if current is the global master""" + rank = Framework.get_rank() + return rank in {-1, 0} + + @staticmethod + def is_last_rank() -> bool: + """Get if current is the last rank""" + rank = Framework.get_rank() + world_size = Framework.get_world_size() + return rank in {-1, world_size - 1} + + @staticmethod + def get_framework(module) -> Literal['torch', 'other']: + """Get the framework""" + if "torch" in type(module).__module__ or hasattr(module, "parameters"): + return "torch" + return 'other' + + @staticmethod + def get_library(module) -> Literal['transformers', 'megatron', 'other']: + """Get The library of one module + + Args: + module: A torch.nn.Module instance + + Returns: + A string representing the library, supports `transformers` or `megatron` or `other` + """ + if Framework.get_framework(module) == 'torch': + return Torch.get_library(module) + return 'other' + + +class Torch(Framework): + + @staticmethod + def get_library(module) -> Literal['transformers', 'megatron', 'other']: 
+ module_path = type(module).__module__ + if "transformers" in module_path: + return "transformers" + elif "megatron" in module_path: + return "megatron" + else: + return "other" + + @staticmethod + def is_torch_available() -> bool: + """Check if `torch` is installed""" + return importlib.util.find_spec('torch') is not None + + @staticmethod + def is_torch_npu_available() -> bool: + """Check if `torch_npu` is installed""" + return importlib.util.find_spec('torch_npu') is not None + + @staticmethod + def is_gpu_available() -> bool: + "Checks if at least one GPU device is available" + if not Torch.is_torch_available(): + return False + + import torch + if not hasattr(torch, 'cuda'): + return False + + return torch.cuda.is_available() + + @staticmethod + def is_npu_available() -> bool: + "Checks if `torch_npu` is installed and if at least one NPU device is available" + if not Torch.is_torch_available() or not Torch.is_torch_npu_available(): + return False + + import torch + import torch_npu + if not hasattr(torch, 'npu'): + return False + + return torch.npu.is_available() and torch.npu.device_count() > 0 + + @staticmethod + @lru_cache + def get_current_device() -> 'Union[int, str, "torch.device"]': + import torch + if Torch.is_gpu_available(): + return torch.cuda.current_device() + elif Torch.is_npu_available(): + import torch_npu + return torch.npu.current_device() + else: + return 'cpu' + + @staticmethod + def get_device(local_rank) -> str: + if local_rank is None: + local_rank = max(0, Torch.get_local_rank()) + local_rank = str(local_rank) + if Torch.is_gpu_available(): + device = 'cuda:{}'.format(local_rank) + elif Torch.is_npu_available(): + device = 'npu:{}'.format(local_rank) + else: + device = 'cpu' + return device + + @staticmethod + def set_device(local_rank: Union[int, str]) -> None: + import torch + if local_rank is None: + local_rank = max(0, Torch.get_local_rank()) + if Torch.is_gpu_available(): + torch.cuda.set_device(local_rank) + elif 
Torch.is_npu_available(): + import torch_npu + torch.npu.set_device(local_rank) diff --git a/src/twinkle/utils/import_utils.py b/src/twinkle/utils/import_utils.py new file mode 100644 index 00000000..78b632c3 --- /dev/null +++ b/src/twinkle/utils/import_utils.py @@ -0,0 +1,91 @@ +import importlib +import importlib.util +import importlib.metadata +import os +from itertools import chain +from types import ModuleType +from typing import Any + +from packaging.requirements import Requirement + + +def requires(package: str): + req = Requirement(package) + pkg_name = req.name + + if importlib.util.find_spec(pkg_name) is None: + raise ImportError(f"Required package '{pkg_name}' is not installed") + + if req.specifier: + try: + installed_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + raise ImportError(f"Cannot determine version for '{pkg_name}'") + + if not req.specifier.contains(installed_version): + raise ImportError( + f"Package '{pkg_name}' version {installed_version} " + f"does not satisfy {req.specifier}" + ) + + +def exists(package: str): + try: + requires(package) + return True + except ImportError: + return False + + +class _LazyModule(ModuleType): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. 
+ """ + + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None): + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. + for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + else: + raise AttributeError(f'module {self.__name__} has no attribute {name}') + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str): + return importlib.import_module('.' 
+ module_name, self.__name__) + + def __reduce__(self): + return self.__class__, (self._name, self.__file__, self._import_structure) \ No newline at end of file diff --git a/src/twinkle/version.py b/src/twinkle/version.py new file mode 100644 index 00000000..f36a16bb --- /dev/null +++ b/src/twinkle/version.py @@ -0,0 +1,5 @@ +# Make sure to modify __release_datetime__ to release time when making official release. +__version__ = '0.0.1.dev0' +# default release datetime for branches under active development is set +# to be a time far-far-away-into-the-future +__release_datetime__ = '2099-10-13 08:56:12' From 92a857c9a4f41913ac0d09b5bf462f79728a8893 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 19 Dec 2025 14:41:47 +0800 Subject: [PATCH 003/692] wip --- src/twinkle/hub/__init__.py | 1 + src/twinkle/hub/hub.py | 424 +++++++++++++++++++++++++++++++ src/twinkle/kernel/__init__.py | 45 +++- src/twinkle/uploader/__init__.py | 0 4 files changed, 463 insertions(+), 7 deletions(-) create mode 100644 src/twinkle/hub/__init__.py create mode 100644 src/twinkle/hub/hub.py delete mode 100644 src/twinkle/uploader/__init__.py diff --git a/src/twinkle/hub/__init__.py b/src/twinkle/hub/__init__.py new file mode 100644 index 00000000..87f35537 --- /dev/null +++ b/src/twinkle/hub/__init__.py @@ -0,0 +1 @@ +from .hub import MSHub as ms, HFHub as hf \ No newline at end of file diff --git a/src/twinkle/hub/hub.py b/src/twinkle/hub/hub.py new file mode 100644 index 00000000..114e2a6c --- /dev/null +++ b/src/twinkle/hub/hub.py @@ -0,0 +1,424 @@ +import os +import tempfile +from contextlib import contextmanager +from pathlib import Path +from typing import List, Literal, Optional, Union + +from requests.exceptions import HTTPError + +from ..utils import requires + + +class HubOperation: + + @classmethod + @contextmanager + def patch_hub(cls): + yield + + @classmethod + def try_login(cls, token: Optional[str] = None) -> bool: + """Try to login to the hub + + Args: + token: The hub 
token to use + + Returns: + bool: Whether login is successful + """ + raise NotImplementedError + + @classmethod + def create_model_repo(cls, repo_id: str, token: Optional[str] = None, private: bool = False): + """Create a model repo on the hub + + Args: + repo_id: The model id of the hub + token: The hub token to use + private: If is a private repo + """ + raise NotImplementedError + + @classmethod + def push_to_hub(cls, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + private: bool = False, + revision: Optional[str] = 'master', + ignore_patterns: Optional[Union[List[str], str]] = None, + **kwargs): + """Push a model-like folder to the hub + + Args: + repo_id: The repo id + folder_path: The local folder path + path_in_repo: Which remote folder to put the local files in + commit_message: The commit message of git + commit_description: The commit description + token: The hub token + private: Private hub or not + revision: The revision to push to + ignore_patterns: The ignore file patterns + """ + raise NotImplementedError + + @classmethod + def load_dataset(cls, + dataset_id: str, + subset_name: str, + split: str, + streaming: bool = False, + revision: Optional[str] = None): + """Load a dataset from the repo + + Args: + dataset_id: The dataset id + subset_name: The subset name of the dataset + split: The split info + streaming: Streaming mode + revision: The revision of the dataset + + Returns: + The Dataset instance + """ + raise NotImplementedError + + @classmethod + def download_model(cls, + model_id_or_path: Optional[str] = None, + revision: Optional[str] = None, + download_model: bool = True, + ignore_patterns: Optional[List[str]] = None, + **kwargs): + """Download model from the hub + + Args: + model_id_or_path: The model id + revision: The model revision + download_model: Whether downloading 
bin/safetensors files, this is usually useful when only + using tokenizer + ignore_patterns: Custom ignore pattern + **kwargs: + + Returns: + The local dir + """ + raise NotImplementedError + + +class MSHub(HubOperation): + ms_token = None + + @staticmethod + def create_repo(repo_id: str, + *, + token: Optional[Union[str, bool]] = None, + private: bool = False, + **kwargs) -> 'modelscope.utils.repo_utils.RepoUrl': + """ + Create a new repository on the hub. + + Args: + repo_id: The ID of the repository to create. + token: The authentication token to use. + private: Whether the repository should be private. + **kwargs: Additional arguments. + + Returns: + RepoUrl: The URL of the created repository. + """ + requires('modelscope') + hub_model_id = MSHub.create_model_repo(repo_id, token, private) + from modelscope.utils.repo_utils import RepoUrl + return RepoUrl(url=hub_model_id, ) + + @staticmethod + def upload_folder( + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + revision: Optional[str] = 'master', + ignore_patterns: Optional[Union[List[str], str]] = None, + **kwargs, + ): + requires('modelscope') + from modelscope.utils.repo_utils import CommitInfo + MSHub.push_to_hub(repo_id, folder_path, path_in_repo, commit_message, commit_description, token, True, revision, + ignore_patterns) + return CommitInfo( + commit_url=f'https://www.modelscope.cn/models/{repo_id}/files', + commit_message=commit_message, + commit_description=commit_description, + oid='', + ) + + @classmethod + def try_login(cls, token: Optional[str] = None) -> bool: + requires('modelscope') + from modelscope import HubApi + if token is None: + token = os.environ.get('MODELSCOPE_API_TOKEN') + if token: + api = HubApi() + api.login(token) + return True + return False + + @classmethod + def create_model_repo(cls, repo_id: str, token: 
Optional[str] = None, private: bool = False) -> str: + requires('modelscope') + from modelscope import HubApi + from modelscope.hub.api import ModelScopeConfig + from modelscope.hub.constants import ModelVisibility + assert repo_id is not None, 'Please enter a valid hub_model_id' + + if not cls.try_login(token): + raise ValueError('Please specify a token by `--hub_token` or `MODELSCOPE_API_TOKEN=xxx`') + cls.ms_token = token + visibility = ModelVisibility.PRIVATE if private else ModelVisibility.PUBLIC + api = HubApi() + if '/' not in repo_id: + user_name = ModelScopeConfig.get_user_info()[0] + assert isinstance(user_name, str) + try: + api.create_model(repo_id, visibility) + except HTTPError: + # The remote repository has been created + pass + + with tempfile.TemporaryDirectory() as temp_cache_dir: + from modelscope.hub.repository import Repository + repo = Repository(temp_cache_dir, repo_id) + cls.add_patterns_to_gitattributes(repo, ['*.safetensors', '*.bin', '*.pt']) + # Add 'runs/' to .gitignore, ignore tensorboard files + cls.add_patterns_to_gitignore(repo, ['runs/', 'images/']) + cls.add_patterns_to_file( + repo, + 'configuration.json', ['{"framework": "pytorch", "task": "text-generation", "allow_remote": true}'], + ignore_push_error=True) + # Add '*.sagemaker' to .gitignore if using SageMaker + if os.environ.get('SM_TRAINING_ENV'): + cls.add_patterns_to_gitignore(repo, ['*.sagemaker-uploading', '*.sagemaker-uploaded'], + 'Add `*.sagemaker` patterns to .gitignore') + return repo_id + + @classmethod + def push_to_hub(cls, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + private: bool = False, + revision: Optional[str] = 'master', + ignore_patterns: Optional[Union[List[str], str]] = None, + **kwargs): + requires('modelscope') + cls.create_model_repo(repo_id, token, private) + from modelscope 
import push_to_hub + commit_message = commit_message or 'Upload folder using api' + if commit_description: + commit_message = commit_message + '\n' + commit_description + if not os.path.exists(os.path.join(folder_path, 'configuration.json')): + with open(os.path.join(folder_path, 'configuration.json'), 'w', encoding='utf-8') as f: + f.write('{"framework": "pytorch", "task": "text-generation", "allow_remote": true}') + if ignore_patterns: + ignore_patterns = [p for p in ignore_patterns if p != '_*'] + if path_in_repo: + # We don't support part submit for now + path_in_repo = os.path.basename(folder_path) + folder_path = os.path.dirname(folder_path) + ignore_patterns = [] + if revision is None or revision == 'main': + revision = 'master' + push_to_hub( + repo_id, + folder_path, + token or cls.ms_token, + private, + commit_message=commit_message, + ignore_file_pattern=ignore_patterns, + revision=revision, + tag=path_in_repo) + + @classmethod + def load_dataset(cls, + dataset_id: str, + subset_name: str, + split: str, + streaming: bool = False, + revision: Optional[str] = None, + download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists', + token: Optional[str] = None, + **kwargs): + requires('modelscope') + from modelscope import MsDataset + cls.try_login(token) + if revision is None or revision == 'main': + revision = 'master' + load_kwargs = {'trust_remote_code': True} + return MsDataset.load( + dataset_id, + subset_name=subset_name, + split=split, + version=revision, + download_mode=download_mode, # noqa + use_streaming=streaming, + **load_kwargs, + ) + + @classmethod + def download_model(cls, + model_id_or_path: Optional[str] = None, + revision: Optional[str] = None, + ignore_patterns: Optional[List[str]] = None, + token: Optional[str] = None, + **kwargs): + requires('modelscope') + cls.try_login(token) + if revision is None or revision == 'main': + revision = 'master' + from modelscope import snapshot_download + return 
snapshot_download(model_id_or_path, revision, ignore_patterns=ignore_patterns, **kwargs) + + @staticmethod + def add_patterns_to_file(repo, + file_name: str, + patterns: List[str], + commit_message: Optional[str] = None, + ignore_push_error=False) -> None: + if isinstance(patterns, str): + patterns = [patterns] + if commit_message is None: + commit_message = f'Add `{patterns[0]}` patterns to {file_name}' + + # Get current file content + repo_dir = repo.model_dir + file_path = os.path.join(repo_dir, file_name) + if os.path.exists(file_path): + with open(file_path, 'r', encoding='utf-8') as f: + current_content = f.read() + else: + current_content = '' + # Add the patterns to file + content = current_content + for pattern in patterns: + if pattern not in content: + if len(content) > 0 and not content.endswith('\n'): + content += '\n' + content += f'{pattern}\n' + + # Write the file if it has changed + if content != current_content: + with open(file_path, 'w', encoding='utf-8') as f: + f.write(content) + try: + repo.push(commit_message) + except Exception as e: + if ignore_push_error: + pass + else: + raise e + + @staticmethod + def add_patterns_to_gitignore(repo, patterns: List[str], commit_message: Optional[str] = None) -> None: + MSHub.add_patterns_to_file(repo, '.gitignore', patterns, commit_message, ignore_push_error=True) + + @staticmethod + def add_patterns_to_gitattributes(repo, patterns: List[str], commit_message: Optional[str] = None) -> None: + new_patterns = [] + suffix = 'filter=lfs diff=lfs merge=lfs -text' + for pattern in patterns: + if suffix not in pattern: + pattern = f'{pattern} {suffix}' + new_patterns.append(pattern) + file_name = '.gitattributes' + if commit_message is None: + commit_message = f'Add `{patterns[0]}` patterns to {file_name}' + MSHub.add_patterns_to_file(repo, file_name, new_patterns, commit_message, ignore_push_error=True) + + +class HFHub(HubOperation): + + @classmethod + def try_login(cls, token: Optional[str] = None) -> bool: + 
pass + + @classmethod + def create_model_repo(cls, repo_id: str, token: Optional[str] = None, private: bool = False) -> str: + requires('huggingface_hub') + from huggingface_hub.hf_api import api + return api.create_repo(repo_id, token=token, private=private) + + @classmethod + def push_to_hub(cls, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + private: bool = False, + revision: Optional[str] = 'master', + ignore_patterns: Optional[Union[List[str], str]] = None, + **kwargs): + requires('huggingface_hub') + from huggingface_hub.hf_api import api + cls.create_model_repo(repo_id, token, private) + if revision is None or revision == 'master': + revision = 'main' + return api.upload_folder( + repo_id=repo_id, + folder_path=folder_path, + path_in_repo=path_in_repo, + commit_message=commit_message, + commit_description=commit_description, + token=token, + revision=revision, + ignore_patterns=ignore_patterns, + **kwargs) + + @classmethod + def load_dataset(cls, + dataset_id: str, + subset_name: str, + split: str, + streaming: bool = False, + revision: Optional[str] = None, + download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists', + num_proc: Optional[int] = None, + **kwargs): + requires('huggingface_hub') + requires('datasets') + from datasets import load_dataset + if revision is None or revision == 'master': + revision = 'main' + return load_dataset( + dataset_id, + name=subset_name, + split=split, + streaming=streaming, + revision=revision, + download_mode=download_mode, + num_proc=num_proc) + + @classmethod + def download_model(cls, + model_id_or_path: Optional[str] = None, + revision: Optional[str] = None, + ignore_patterns: Optional[List[str]] = None, + **kwargs): + if revision is None or revision == 'master': + revision = 'main' + from huggingface_hub import 
snapshot_download + return snapshot_download( + model_id_or_path, repo_type='model', revision=revision, ignore_patterns=ignore_patterns, **kwargs) diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py index 81448919..83fc4660 100644 --- a/src/twinkle/kernel/__init__.py +++ b/src/twinkle/kernel/__init__.py @@ -1,11 +1,18 @@ +import re +from types import MethodType from typing import Union, Callable, Any, List, Optional, Literal from ..utils import torch as torch_util, framework as framework_util from ..utils import exists +kernel_mapping = { + +} + + def apply_kernel(module: Any, mode: Literal['train', 'inference', 'compile', None] = 'train', - kernel: Optional[Union[str, Callable[[*Any], Any]]]=None, + kernel: "Optional[Union[str, Callable, 'torch.nn.Module']]"=None, target_modules: Union[str, List[str]]=None, device: Optional[Union[str, Any]] = None, ) -> Any: @@ -22,16 +29,40 @@ def apply_kernel(module: Any, return kernelize(module, mode=kernel_mode, device=device) assert target_modules is not None and kernel is not None - - + return apply_kernel_torch(module, kernel, target_modules=target_modules) else: raise NotImplementedError(f'Unsupported applying kernels for: {module.__class__}') def apply_kernel_torch(module: Any, - mode: Literal['train', 'inference', 'compile', None] = 'train', - kernel: Optional[Union[str, Callable[[*Any], Any]]]=None, - target_modules: Union[str, List[str]]=None, - device: Optional[Union[str, Any]] = None,): + kernel: "Optional[Union[str, Callable, 'torch.nn.Module']]", + target_modules: Union[str, List[str]]): + if kernel in kernel_mapping: + kernel = kernel_mapping[kernel] + + kernel_fn = kernel + import torch + if isinstance(kernel_fn, torch.nn.Module): + kernel_fn = kernel_fn.forward + + if target_modules is None: + raise ValueError(f'Module patching needs a valid `target_modules` parameter,' + f'but current is: {target_modules}') + + if isinstance(target_modules, str): + pattern = re.compile(target_modules) + 
for name, submodule in module.named_modules(): + if pattern.search(name): + if not hasattr(submodule, '__origin_forward__'): + submodule.__origin_forward__ = submodule.forward + submodule.forward = MethodType(kernel_fn, submodule) + + elif isinstance(target_modules, list): + for name, submodule in module.named_modules(): + if any(name.endswith(target) for target in target_modules): + if not hasattr(submodule, '__origin_forward__'): + submodule.__origin_forward__ = submodule.forward + submodule.forward = MethodType(kernel_fn, submodule) + return module diff --git a/src/twinkle/uploader/__init__.py b/src/twinkle/uploader/__init__.py deleted file mode 100644 index e69de29b..00000000 From 99d58267e8434e5b28ae8deb8a65c51948bda87f Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 19 Dec 2025 15:09:54 +0800 Subject: [PATCH 004/692] wip --- pyproject.toml | 38 +++ src/twinkle/infra/ray/base.py | 365 ++++++++++++++++++++++ src/twinkle/infra/ray/resource_manager.py | 138 ++++++++ 3 files changed, 541 insertions(+) create mode 100644 pyproject.toml create mode 100644 src/twinkle/infra/ray/base.py create mode 100644 src/twinkle/infra/ray/resource_manager.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..11f1fe33 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[tool.poetry.dependencies] +python = "^3.11" +datasets = "^3.0" +binpacking = "*" +importlib_metadata = "*" +numpy = "*" +omegaconf = "*" +accelerate = {optional = true} +peft = {version = "^0.11.0", optional = true} +pillow = {optional = true} +rouge = {optional = true} +safetensors = {optional = true} +scipy = {optional = true} +sentencepiece = {optional = true} +ray = {optional = true} +transformers = {optional = true} +megatron-core = {version = "^0.12.0", optional = true} +torch = {version = "^2.0.0", optional = true} +torchvision = {optional = true} +deepspeed = {optional = true} +ray = {optional = true} +sphinx = {version = "^5.3.0", optional = true} +docutils = 
{version = "^0.16.0", optional = true} +myst_parser = {optional = true} +recommonmark = {optional = true} +sphinx-book-theme = {optional = true} +sphinx-copybutton = {optional = true} +sphinx-rtd-theme = {optional = true} +sphinx_markdown_tables = {optional = true} +sphinxcontrib-mermaid = {optional = true} + +[tool.poetry.extras] +transformers = ["accelerate", "peft", "transformers", "safetensors", "torch", "torchvision"] +megatron = ["megatron-core"] +ray = ["ray"] +docs = ["sphinx", "docutils", "myst_parser", "recommonmark", "sphinx-book-theme", "sphinx-copybutton", "sphinx-rtd-theme", "sphinx_markdown_tables", "sphinxcontrib-mermaid"] + + diff --git a/src/twinkle/infra/ray/base.py b/src/twinkle/infra/ray/base.py new file mode 100644 index 00000000..7f9ae4eb --- /dev/null +++ b/src/twinkle/infra/ray/base.py @@ -0,0 +1,365 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import functools +import inspect +import os +from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union + +import json +import numpy as np + +from .resource_manager import ResourceManager + +T = TypeVar('T') + + +def get_args(): + parser = argparse.ArgumentParser() + _, unknown = parser.parse_known_args() + return json.dumps(unknown) + + +class RayHelper: + resource_manager: Optional[ResourceManager] = None + + worker_cls: Dict = {} + + worker_instance: Dict = {} + + initialized = False + + device_groups: Dict[str, Any] = None + + @staticmethod + def initialize(device_groups: Dict[str, Any]): + """Initialize RayHelper. + + Args: + device_groups: The device groups to initialize. + + Returns: + None + """ + if RayHelper.ray_inited(): + return + import ray + RayHelper.device_groups = device_groups + ray.init() + if RayHelper.resource_manager is None: + # Resource manager initialize only once in the pipeline process. 
+ RayHelper.resource_manager = ResourceManager(device_groups) + + @staticmethod + def teardown(): + if RayHelper.resource_manager is not None: + RayHelper.resource_manager.destroy_placement_group() + RayHelper.resource_manager = None + + @staticmethod + def is_called_from_init(): + """If some function called from __init__. + + Ray functions perform different behaviors depending on whether they are called from __init__. + + Returns: + Boolean. + """ + stack = inspect.stack() + for frame_info in stack[1:]: + if frame_info.function == '__init__': + return True + return False + + @staticmethod + def ray_inited(): + try: + import ray + except ImportError: + # not installed, not inited + return False + return ray.is_initialized() + + @staticmethod + def is_worker(): + import ray + return RayHelper.ray_inited() and ray._private.worker.global_worker.mode == ray._private.worker.WORKER_MODE + + @staticmethod + def worker(group: Union[str, List[str]]): + + def decorator(cls): + if not RayHelper.ray_inited(): + return cls + if RayHelper.is_worker(): + return cls + cls.decorated = True + groups = [group] if isinstance(group, str) else group + import ray + _cls = ray.remote(cls) + for g in groups: + RayHelper.worker_cls[g] = _cls + + init_method = cls.__init__ + + @functools.wraps(init_method) + def new_init(self, *args, **kwargs): + if not RayHelper.is_worker(): + # Create remote workers + RayHelper._create_workers(group, *args, **kwargs) + init_method(self, *args, **kwargs) + + cls.__init__ = new_init + + return cls + + return decorator + + @staticmethod + def collect_func(method: Union[Literal['none', 'flatten'], Callable], result): + if isinstance(result[0], tuple): + output = [] + for i in range(len(result[0])): + _single_result = [r[i] for r in result] + output.append(RayHelper.collect_func(method, _single_result)) + return output + if method == 'none': + return result + elif method == 'flatten': + flatten = [item for sublist in result for item in sublist] + if 
isinstance(result[0], np.ndarray): + return np.array(flatten) + return type(result[0])(flatten) + elif isinstance(method, Callable): + # Callable + return method(result) + else: + raise ValueError(f'Unsupported collect method: {method}') + + @staticmethod + def function(group: str, + dispatch: Union[Literal['slice', 'all'], Callable] = 'all', + execute: Literal['first', 'all'] = 'all', + collect: Union[Literal['none', 'flatten'], Callable] = 'none'): + """Remote execution function. + + Args: + group: The group to execute. + dispatch: How to dispatch the arguments. + 'slice': load balance + 'all': all processes do the same thing + execute: How to execute + 'first': Only first worker + 'all': All processes + collect: How to collect the results. + 'none': Return as-is + 'flatten': Return a flattened list + Returns: + The execution result. + """ + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + + @functools.wraps(func) + def wrapper(self, *args, **kwargs) -> T: + if not RayHelper.ray_inited(): + return func(self, *args, **kwargs) + if RayHelper.is_worker(): + if not hasattr(self, 'group'): + # pass through env + self.group = os.environ['RAY_SWIFT_GROUP'].split(',') + if group not in self.group: + if RayHelper.is_called_from_init(): + # Functions in init of different group, do nothing + return None + else: + # Should not happen + raise ValueError() + else: + return func(self, *args, **kwargs) + else: + if RayHelper.is_called_from_init(): + # each worker do its own init + return None + result = RayHelper.execute_all_sync(group, dispatch, execute, func.__name__, *args, **kwargs) + return RayHelper.collect_func(collect, result) + + return wrapper + + return decorator + + @staticmethod + def execute_all_sync(group, dispatch, execute, method_name: str, *args, **kwargs): + import ray + return ray.get(RayHelper.execute_all_async(group, dispatch, execute, method_name, *args, **kwargs)) + + @staticmethod + def execute_all_async(group, dispatch, execute, 
method_name: str, *args, **kwargs): + workers = RayHelper.worker_instance[group] + length = len(workers) + if execute == 'first': + return getattr(workers[0], method_name).remote(*args, **kwargs) + elif dispatch == 'all': + return [getattr(worker, method_name).remote(*args, **kwargs) for worker in workers] + elif dispatch == 'slice': + result = [] + + def dispatch_func(arg, n): + if isinstance(arg, list): + k, m = divmod(len(arg), n) + return [arg[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)] + else: + return [arg] * n + + args = [dispatch_func(arg, length) for arg in args] + kwargs = {k: dispatch_func(v, length) for k, v in kwargs.items()} + for i in range(length): + sliced_args = tuple(arg[i] for arg in args) + sliced_kwargs = {k: v[i] for k, v in kwargs.items()} + if (sliced_args and sliced_args[0]) or (kwargs and list(kwargs.values())): + # skip empty input + remote_call = getattr(workers[i], method_name) + result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) + return result + elif isinstance(dispatch, Callable): + # dispatch is Callable + result = [] + for i in range(length): + sliced_args, sliced_kwargs = dispatch(length, i, *args, **kwargs) + remote_call = getattr(workers[i], method_name) + result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) + return result + else: + raise ValueError(f'Invalid dispatch method: {dispatch}') + + @staticmethod + def _create_workers(group: Union[str, List[str]], *args, **kwargs): + import ray + from ray.runtime_env import RuntimeEnv + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + exp_name = os.environ.get('RAY_SWIFT_EXP_NAME') + if not exp_name: + exp_name = '' + else: + exp_name += '-' + + if isinstance(group, str): + group = [group] + + for _group in group: + if _group in RayHelper.worker_instance: + continue + + worker_cls = RayHelper.worker_cls[_group] + + _config = None + for name, config in RayHelper.device_groups.items(): + if name in 
RayHelper.resource_manager.possible_keys: + continue + + if _group in config['workers']: + _config = config + break + + assert _config is not None + local_groups = _config['workers'] + + VISIBLE_ENV_MAPPING = { + 'GPU': 'CUDA_VISIBLE_DEVICES', + 'NPU': 'ASCEND_VISIBLE_DEVICES', + } + + if _config['device'].upper() != 'CPU': + world_size = len(_config['ranks']) + placement_groups: List[List[Dict]] = RayHelper.resource_manager.resource(_group) + workers = [] + ip, port = None, None + for rank, (deploy_pg, gpu) in enumerate(zip(placement_groups, _config['ranks'])): + deploy_pg: Dict + cluster_name = exp_name + '-'.join(local_groups) + worker_name = cluster_name + '-' + str(rank) + env_vars = os.environ.copy() + env_vars.update({ + 'WORLD_SIZE': + str(world_size), + 'RANK': + str(rank), + 'LOCAL_RANK': + str(0), + 'CLUSTER_NAME': + cluster_name, + 'WORKER_NAME': + worker_name, + VISIBLE_ENV_MAPPING[_config['device'].upper()]: + ','.join([str(r) for r in deploy_pg['gpu_rank']]), # TODO npu + 'RAY_SWIFT_ARGS': + get_args(), # pass through env + }) + + @ray.remote + def get_node_address(): + return find_node_ip(), find_free_port() + + if rank == 0: + ip, port = ray.get( + get_node_address.options(placement_group=deploy_pg['placement_group']).remote()) + + env_vars['MASTER_ADDR'] = ip + env_vars['MASTER_PORT'] = str(port) + env_vars['RAY_SWIFT_GROUP'] = ','.join(local_groups) + + runtime_env = RuntimeEnv(env_vars=env_vars) + + worker_options = { + 'scheduling_strategy': + PlacementGroupSchedulingStrategy(placement_group=deploy_pg['placement_group']), + 'name': + worker_name, + 'namespace': + 'default', + 'runtime_env': + runtime_env, + 'num_cpus': + 0.01, + 'num_gpus': + 0.01, + } + + worker = worker_cls.options(**worker_options).remote(*args, **kwargs) + workers.append(worker) + else: + world_size = _config['ranks'] + placement_groups: List[List[Dict]] = RayHelper.resource_manager.resource(_group) + workers = [] + for deploy_pg, index in zip(placement_groups, 
list(range(world_size))): + deploy_pg: Dict + cluster_name = '-'.join(local_groups) + worker_name = cluster_name + '-' + str(index) + env_vars = os.environ.copy() + env_vars.update({ + 'CLUSTER_NAME': cluster_name, + 'WORKER_NAME': worker_name, + VISIBLE_ENV_MAPPING[_config['device'].upper()]: '', + 'RAY_SWIFT_ARGS': get_args(), # pass through env + }) + env_vars['RAY_SWIFT_GROUP'] = ','.join(local_groups) + + runtime_env = RuntimeEnv(env_vars=env_vars) + + worker_options = { + 'scheduling_strategy': + PlacementGroupSchedulingStrategy(placement_group=deploy_pg['placement_group']), + 'name': + worker_name, + 'namespace': + 'default', + 'runtime_env': + runtime_env, + 'num_cpus': + 0.01, + } + + worker = worker_cls.options(**worker_options).remote(*args, **kwargs) + workers.append(worker) + + for g in local_groups: + RayHelper.worker_instance[g] = workers diff --git a/src/twinkle/infra/ray/resource_manager.py b/src/twinkle/infra/ray/resource_manager.py new file mode 100644 index 00000000..885dc6b4 --- /dev/null +++ b/src/twinkle/infra/ray/resource_manager.py @@ -0,0 +1,138 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+# Some code borrowed from ROLL: https://github.com/alibaba/ROLL +import ast +import math +import os +from dataclasses import dataclass, field +from typing import Any, Dict, List + + +@dataclass +class NodeGroup: + device_count: int + nodes: List[Any] = field(default_factory=list) + + +def get_node_rank(): + return int(os.environ.get('NODE_RANK', '0')) + + +class ResourceManager: + + possible_keys = ['nproc_per_node', 'nnodes'] + + def __init__(self, groups: Dict[str, Any]): + import ray + from ray.util.placement_group import PlacementGroup + nproc_per_node = int(groups['nproc_per_node']) + device_types = set([group['device'].upper() + for group in groups.values() if hasattr(group, '__getitem__')]) - {'CPU'} + assert len(device_types) == 1 + device_type = next(iter(device_types)) + all_ranks = [] + last_rank = -1 + cpu_proc_count = 0 + for group_name, group in groups.items(): + if group_name in self.possible_keys: + continue + ranks = group['ranks'] + device = group['device'].upper() + if device == 'CPU': + assert isinstance(ranks, int), 'CPU group only supports integer ranks' + cpu_proc_count += ranks + continue + try: + ranks = int(ranks) # int type + ranks = list(range(last_rank + 1, last_rank + 1 + ranks)) + except Exception: # noqa + if isinstance(ranks, str): + ranks = eval(ranks, {'__builtins__': {'list': list, 'range': range}}) + finally: + all_ranks.extend(ranks) + group['ranks'] = ranks + last_rank = ranks[-1] + + assert len(set(all_ranks)) == len(all_ranks) + groups['nnodes'] = math.ceil(len(all_ranks) / nproc_per_node) + + self.nodes = [] + for node in ray.nodes(): + resource = node['Resources'] + node_gpu_num = int(resource.get(device_type, 0)) + if node_gpu_num >= nproc_per_node: + self.nodes.append(node) + + bundles = [] + cpu_bundles = [] + for i in range(groups['nnodes']): + node = self.nodes[i] + node_cpu = int(node['Resources']['CPU']) + bundles.append({device_type: nproc_per_node, 'CPU': node_cpu // 2 + 1}) + cpu_bundles.append({'CPU': node_cpu 
// 4 + 1}) # TODO dynamic scheduling + + nproc_cpu_per_node = cpu_proc_count // len(cpu_bundles) + 1 + self.cpu_node_map = {} + for i in range(cpu_proc_count): + node_idx = i // nproc_cpu_per_node + cpu_cnt = cpu_bundles[node_idx]['CPU'] + self.cpu_node_map[i] = (node_idx, cpu_cnt // nproc_cpu_per_node) + + self.placement_groups = [ray.util.placement_group([bundle]) for bundle in bundles] + self.cpu_placement_groups = [ray.util.placement_group([bundle]) for bundle in cpu_bundles] + cpu_bundles.sort(key=lambda bundle: bundle['CPU'], reverse=True) + ray.get([pg.ready() for pg in self.placement_groups]) + ray.get([pg.ready() for pg in self.cpu_placement_groups]) + + self.node_ranks = ray.get( + [ray.remote(get_node_rank).options(placement_group=pg).remote() for pg in self.placement_groups]) + if self.node_ranks.count(0) > 1: + self.node_ranks = list(range(len(self.placement_groups))) + + self.node2pg: Dict[int, PlacementGroup] = {} + for node_rank, placement_group in zip(self.node_ranks, self.placement_groups): + self.node2pg[node_rank] = placement_group + + self.device_groups = {} + ray_address = str(ray.get_runtime_context().gcs_address) + for group_name, group in groups.items(): + if group_name in self.possible_keys: + continue + + if group['device'] != 'CPU': + ranks = group['ranks'] + local_device_groups = [] + for rank in ranks: + node_rank = rank // nproc_per_node + gpu_rank = rank % nproc_per_node + local_device_groups.append( + dict( + node_rank=node_rank, + gpu_rank=[gpu_rank], + placement_group=self.node2pg[node_rank], + ray_address=ray_address)) + for worker in group['workers']: + self.device_groups[worker] = local_device_groups + else: + ranks = group['ranks'] + local_device_groups = [] + global_cpu_proc_idx = 0 + for _ in range(ranks): + local_device_groups.append( + dict( + placement_group=self.cpu_placement_groups[self.cpu_node_map[global_cpu_proc_idx][0]], + ray_address=ray_address)) + global_cpu_proc_idx += 1 + for worker in group['workers']: + 
self.device_groups[worker] = local_device_groups + + self.groups = groups + + def resource(self, worker): + return self.device_groups[worker] + + def destroy_placement_group(self): + import ray + for pg in self.placement_groups: + ray.util.remove_placement_group(pg) + for pg in self.cpu_placement_groups: + ray.util.remove_placement_group(pg) From a48570cd546befec47f2f93b9701a2ccf19c4aa3 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 19 Dec 2025 17:39:30 +0800 Subject: [PATCH 005/692] wip --- src/twinkle/__init__.py | 4 +- src/twinkle/dataset/llm.py | 934 ++++++++ src/twinkle/dataset/mllm.py | 1327 +++++++++++ src/twinkle/kernel/__init__.py | 7 +- src/twinkle/loss/__init__.py | 25 + src/twinkle/loss/base.py | 6 + src/twinkle/loss/chunked_cross_entropy.py | 55 + src/twinkle/loss/contrastive_loss.py | 23 + src/twinkle/loss/cosine_similarity.py | 11 + src/twinkle/loss/cross_entropy.py | 8 + src/twinkle/loss/generative_reranker.py | 56 + src/twinkle/loss/infonce.py | 263 +++ .../loss/listwise_generative_reranker.py | 114 + src/twinkle/loss/listwise_reranker.py | 90 + src/twinkle/loss/mse.py | 8 + src/twinkle/loss/online_contrastive_loss.py | 23 + src/twinkle/loss/reranker.py | 12 + src/twinkle/metric/__init__.py | 0 src/twinkle/patch/__init__.py | 0 src/twinkle/preprocessor/core.py | 549 +++++ src/twinkle/preprocessor/extra.py | 112 + src/twinkle/template/base.py | 2049 +++++++++++++++++ src/twinkle/template/constant.py | 247 ++ src/twinkle/template/grounding.py | 91 + src/twinkle/template/register.py | 65 + src/twinkle/template/template/__init__.py | 3 + src/twinkle/template/template/baai.py | 204 ++ src/twinkle/template/template/baidu.py | 255 ++ src/twinkle/template/template/bert.py | 5 + src/twinkle/template/template/deepseek.py | 479 ++++ src/twinkle/template/template/dots.py | 62 + src/twinkle/template/template/gemma.py | 243 ++ src/twinkle/template/template/glm.py | 533 +++++ src/twinkle/template/template/idefics3.py | 37 + 
src/twinkle/template/template/internlm.py | 195 ++ src/twinkle/template/template/internvl.py | 365 +++ src/twinkle/template/template/kwai.py | 302 +++ src/twinkle/template/template/llama.py | 216 ++ src/twinkle/template/template/llava.py | 408 ++++ src/twinkle/template/template/llm.py | 409 ++++ src/twinkle/template/template/megrez.py | 97 + src/twinkle/template/template/microsoft.py | 209 ++ src/twinkle/template/template/midashenglm.py | 67 + src/twinkle/template/template/minicpm.py | 312 +++ src/twinkle/template/template/minimax.py | 132 ++ src/twinkle/template/template/mistral.py | 167 ++ src/twinkle/template/template/molmo.py | 68 + src/twinkle/template/template/moonshot.py | 100 + src/twinkle/template/template/mplug.py | 221 ++ src/twinkle/template/template/openbuddy.py | 48 + src/twinkle/template/template/pixtral.py | 68 + src/twinkle/template/template/qwen.py | 1062 +++++++++ src/twinkle/template/template/seed.py | 264 +++ src/twinkle/template/template/stepfun.py | 299 +++ src/twinkle/template/template/utils.py | 75 + src/twinkle/template/template/valley.py | 135 ++ src/twinkle/template/template/yi.py | 63 + src/twinkle/template/template_inputs.py | 341 +++ src/twinkle/template/template_meta.py | 139 ++ src/twinkle/template/utils.py | 158 ++ src/twinkle/template/vision_utils.py | 318 +++ src/twinkle/utils/__init__.py | 4 +- 62 files changed, 14134 insertions(+), 8 deletions(-) create mode 100644 src/twinkle/dataset/llm.py create mode 100644 src/twinkle/dataset/mllm.py create mode 100644 src/twinkle/loss/base.py create mode 100644 src/twinkle/loss/chunked_cross_entropy.py create mode 100644 src/twinkle/loss/contrastive_loss.py create mode 100644 src/twinkle/loss/cosine_similarity.py create mode 100644 src/twinkle/loss/cross_entropy.py create mode 100644 src/twinkle/loss/generative_reranker.py create mode 100644 src/twinkle/loss/infonce.py create mode 100644 src/twinkle/loss/listwise_generative_reranker.py create mode 100644 
src/twinkle/loss/listwise_reranker.py create mode 100644 src/twinkle/loss/mse.py create mode 100644 src/twinkle/loss/online_contrastive_loss.py create mode 100644 src/twinkle/loss/reranker.py create mode 100644 src/twinkle/metric/__init__.py create mode 100644 src/twinkle/patch/__init__.py create mode 100644 src/twinkle/preprocessor/core.py create mode 100644 src/twinkle/preprocessor/extra.py create mode 100644 src/twinkle/template/base.py create mode 100644 src/twinkle/template/constant.py create mode 100644 src/twinkle/template/grounding.py create mode 100644 src/twinkle/template/register.py create mode 100644 src/twinkle/template/template/__init__.py create mode 100644 src/twinkle/template/template/baai.py create mode 100644 src/twinkle/template/template/baidu.py create mode 100644 src/twinkle/template/template/bert.py create mode 100644 src/twinkle/template/template/deepseek.py create mode 100644 src/twinkle/template/template/dots.py create mode 100644 src/twinkle/template/template/gemma.py create mode 100644 src/twinkle/template/template/glm.py create mode 100644 src/twinkle/template/template/idefics3.py create mode 100644 src/twinkle/template/template/internlm.py create mode 100644 src/twinkle/template/template/internvl.py create mode 100644 src/twinkle/template/template/kwai.py create mode 100644 src/twinkle/template/template/llama.py create mode 100644 src/twinkle/template/template/llava.py create mode 100644 src/twinkle/template/template/llm.py create mode 100644 src/twinkle/template/template/megrez.py create mode 100644 src/twinkle/template/template/microsoft.py create mode 100644 src/twinkle/template/template/midashenglm.py create mode 100644 src/twinkle/template/template/minicpm.py create mode 100644 src/twinkle/template/template/minimax.py create mode 100644 src/twinkle/template/template/mistral.py create mode 100644 src/twinkle/template/template/molmo.py create mode 100644 src/twinkle/template/template/moonshot.py create mode 100644 
src/twinkle/template/template/mplug.py create mode 100644 src/twinkle/template/template/openbuddy.py create mode 100644 src/twinkle/template/template/pixtral.py create mode 100644 src/twinkle/template/template/qwen.py create mode 100644 src/twinkle/template/template/seed.py create mode 100644 src/twinkle/template/template/stepfun.py create mode 100644 src/twinkle/template/template/utils.py create mode 100644 src/twinkle/template/template/valley.py create mode 100644 src/twinkle/template/template/yi.py create mode 100644 src/twinkle/template/template_inputs.py create mode 100644 src/twinkle/template/template_meta.py create mode 100644 src/twinkle/template/utils.py create mode 100644 src/twinkle/template/vision_utils.py diff --git a/src/twinkle/__init__.py b/src/twinkle/__init__.py index 4c9dee08..2f2d0370 100644 --- a/src/twinkle/__init__.py +++ b/src/twinkle/__init__.py @@ -3,12 +3,12 @@ if TYPE_CHECKING: from .version import __version__, __release_datetime__ - from .utils import framework, torch, requires, exists + from .utils import framework_util, torch_util, requires, exists else: _import_structure = { 'version': ['__release_datetime__', '__version__'], - 'utils': ['framework', 'torch', 'requires'], + 'utils': ['framework_util', 'torch_util', 'requires'], } import sys diff --git a/src/twinkle/dataset/llm.py b/src/twinkle/dataset/llm.py new file mode 100644 index 00000000..588a4ee6 --- /dev/null +++ b/src/twinkle/dataset/llm.py @@ -0,0 +1,934 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
class AlpacaZhPreprocessor(AlpacaPreprocessor):
    """Alpaca-gpt4-zh variant that strips a leading Chinese '输入:' marker
    from the input field before the standard instruction/input concat."""

    @classmethod
    def concat_inst_input(cls, instruction, input_):
        # Some zh rows prefix the input with a literal '输入:' (3 chars); drop it.
        marker = '输入:'
        if input_ and input_.startswith(marker):
            input_ = input_[len(marker):]
        return super().concat_inst_input(instruction, input_)
+ ms_dataset_id='AI-ModelScope/ruozhiba', + subsets=['post-annual', 'title-good', 'title-norm'], + preprocess_func=RuozhibaPreprocessor(), + tags=['pretrain', '🔥'])) + + +class MathTrnPreprocessor(ResponsePreprocessor): + + def preprocess(self, row): + query = row['query'] + output = row['response'] + row = { + 'query': query, + 'response': output, + } + return super().preprocess(row) + + +register_dataset( + DatasetMeta(ms_dataset_id='AI-ModelScope/math-trn-format', preprocess_func=MathTrnPreprocessor(), tags=['math'])) + + +def _repair_ms_bench(messages: str) -> Optional[List[Dict[str, str]]]: + if isinstance(messages, str): + messages = ast.literal_eval(messages) + default_system = 'You are a helpful assistant.' + messages: List[Dict[str, str]] + if messages[0]['from'] == 'system' and messages[0]['value'] == default_system: + messages.pop(0) + # skip MOSS + for c in messages: + value = c['value'].lower() + if 'moss' in value or 'human:' in value or 'assistant:' in value or 'user:' in value: + return + return messages + + +register_dataset( + DatasetMeta( + ms_dataset_id='iic/ms_bench', + preprocess_func=MessagesPreprocessor(repair_messages=_repair_ms_bench), + tags=['chat', 'general', 'multi-round', '🔥'])) + + +def _repair_agent_messages(messages: List[Dict[str, str]], use_mini: bool) -> Optional[List[Dict[str, str]]]: + if use_mini: + pattern = r'\d\. 
class FireflyPreprocessor(ResponsePreprocessor):
    """Filter firefly-train-1.1M samples down to a whitelist of task kinds."""

    _firefly_kind_list = {
        'ProseGeneration', 'MRC', 'JinYongGeneration', 'TextCorrection', 'ClassicalChinese', 'BELLE',
        'StoryGeneration', 'Couplet', 'Cot', 'Dictionary', 'Translation', 'Program', 'SentimentAnalyze',
        'OpenQA', 'AncientPoem', 'TextMatching', 'NLI', 'Summary', 'KeywordRecognition', 'ProductDesc',
        'LyricGeneration', 'Composition', 'MusicComment', 'NER'
    }

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        # Drop rows whose task kind is not in the supported whitelist.
        if row['kind'] not in FireflyPreprocessor._firefly_kind_list:
            return None
        return super().preprocess(row)
ms_dataset_id='modelscope/clue', + hf_dataset_id='clue', + subsets=['cmnli'], + preprocess_func=ClsGenerationPreprocessor(['neutral', 'entailment', 'contradiction'], + task='Natural Language Inference', + is_pair_seq=True), + tags=['text-generation', 'classification'], + split=['train', 'validation'], + )) + +register_dataset( + DatasetMeta( + ms_dataset_id='DAMO_NLP/jd', + subsets=[ + SubsetDataset( + 'default', + 'default', + preprocess_func=ClsGenerationPreprocessor(['negative', 'positive'], + task='Sentiment Classification', + is_pair_seq=False)), + SubsetDataset( + 'cls', + 'default', + preprocess_func=ClsPreprocessor(columns={'sentence': 'query'}), + ), + ], + tags=['text-generation', 'classification', '🔥'], + split=['train', 'validation'], + )) + + +class SyntheticText2SqlPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + sql_prompt = row['sql_prompt'] + sql_context = row['sql_context'] + sql = row['sql'] + sql_explanation = row['sql_explanation'] + query = f'Sql Table information:\n{sql_context}\n{sql_prompt}' + response = f'Let\'s think step by step:\n{sql_explanation}\nSo the final sql is:\n{sql}' + return super().preprocess({'query': query, 'response': response}) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/synthetic_text_to_sql', + hf_dataset_id='gretelai/synthetic_text_to_sql', + preprocess_func=SyntheticText2SqlPreprocessor(), + tags=['nl2sql', 'en'])) + + +def _repair_toolbench(conversations: List[Dict[str, str]]) -> List[Dict[str, str]]: + assert len(conversations) == 2 + if conversations[1]['from'] in {'caller', 'conclusion'}: + conversations[1]['from'] = 'assistant' + return conversations + + +register_dataset( + DatasetMeta( + ms_dataset_id='shenweizhou/alpha-umi-toolbench-processed-v2', + subsets=['backbone', 'caller', 'planner', 'summarizer'], + preprocess_func=MessagesPreprocessor(repair_messages=_repair_toolbench), + tags=['chat', 'agent', '🔥'], + huge_dataset=True)) 
class TigerBotLawPreprocessor(ResponsePreprocessor):
    """Flatten a tigerbot-law row (type/title/optional chapters/response)
    into a single pretraining text stored under 'response'."""

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        header_template = """{type}
{title}
"""
        pieces = [header_template.format(type=row['type'], title=row['title'])]
        # Up to three optional chapter fields; missing ones are stored as None.
        for idx in (1, 2, 3):
            chapter = row[f'chapter{idx}']
            if chapter is not None:
                pieces.append(f'{chapter}')
        pieces.append(f'{row["response"]}')
        return super().preprocess({'response': ''.join(pieces)})
class StsbPreprocessor(RowPreprocessor):
    """Convert stsb sentence pairs into embedding-style rows.

    When ``sim_threshold`` is given, rows scoring below it are dropped
    (used to mine positives for InfoNCE-style training).
    """

    def __init__(self, sim_threshold: Optional[float] = None):
        self.sim_threshold = sim_threshold
        super().__init__()

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        score = row['score']
        converted = {
            'messages': [{'role': 'user', 'content': row['sentence1']}],
            'positive_messages': [[{'role': 'user', 'content': row['sentence2']}]],
            'label': score,
        }
        # Filter out weakly-similar pairs when a threshold is configured.
        if self.sim_threshold is not None and float(score) < self.sim_threshold:
            return None
        return converted
preprocess_func=StsbRegressionPreprocessor()), + ], + tags=['similarity', '🔥'])) + + +class MTEBRerankPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + query = row['query'] + positives = row['positive'] if isinstance(row['positive'], list) else [row['positive']] + negatives = row['negative'] if isinstance(row['negative'], list) else [row['negative']] + + messages = [{'role': 'user', 'content': query}] + positive_messages = [[{'role': 'assistant', 'content': positive}] for positive in positives] + negative_messages = [[{'role': 'assistant', 'content': negative}] for negative in negatives] + + return {'messages': messages, 'positive_messages': positive_messages, 'negative_messages': negative_messages} + + +register_dataset( + DatasetMeta( + ms_dataset_id='MTEB/scidocs-reranking', + hf_dataset_id='mteb/scidocs-reranking', + split=['validation', 'test'], + preprocess_func=MTEBRerankPreprocessor(), + tags=['rerank', '🔥'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='MTEB/stackoverflowdupquestions-reranking', + hf_dataset_id='mteb/stackoverflowdupquestions-reranking', + split=['train', 'test'], + preprocess_func=MTEBRerankPreprocessor(), + tags=['rerank', '🔥'])) + + +def _repair_conversations_agent_instruct(s: str) -> List[Dict[str, Any]]: + s = s.replace('}\n {', '},\n {') + if isinstance(s, str): + s = ast.literal_eval(s) + return s + + +register_dataset( + DatasetMeta( + ms_dataset_id='huangjintao/AgentInstruct_copy', + subsets=['alfworld', 'db', 'kg', 'mind2web', 'os', 'webshop'], + preprocess_func=MessagesPreprocessor(repair_messages=_repair_conversations_agent_instruct), + tags=['chat', 'agent', 'multi-round'])) + + +class MultiRoleAgentPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + conv = row['conversations'] + res_prompt = '\n\n【注意事项】\n1. 这是聊天室,不要发送私信给任何人\n2. 仅代表你个人说话,不要扮演其他人,只根据对话历史进行回复\n3. 
长话短说,不要说太多话,不要超过50字 ' + history_prompt = '\n\n【chat history】' + conv_prompt = '\n {name}:{content}' + query, response = '', conv[-1]['value'] + system = conv[0]['value'] if conv[0]['from'] == 'system' else '' + if conv[0]['from'] == 'user': + query = conv[0]['value'] + elif 'next_speakers:' not in system: + if '【注意事项】' not in system and system: + system += res_prompt + system += history_prompt + system += ''.join([conv_prompt.format(name=c['from'], content=c['value']) for c in conv[1:-1]]) + + if not query or not response: + return + + return { + 'messages': [{ + 'role': 'system', + 'content': system + }, { + 'role': 'user', + 'content': query + }, { + 'role': 'assistant', + 'content': response + }], + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='iic/MSAgent-MultiRole', + preprocess_func=MultiRoleAgentPreprocessor(), + tags=['chat', 'agent', 'multi-round', 'role-play', 'multi-agent'])) + +register_dataset(DatasetMeta(ms_dataset_id='swift/ToolBench', tags=['chat', 'agent', 'multi-round'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='tastelikefeet/competition_math', + subsets=[ + SubsetDataset( + name='default', + subset='default', + split=['train', 'test'], + ), + ], + tags=['qa', 'math'])) + +register_dataset(DatasetMeta(ms_dataset_id='modelscope/gsm8k', subsets=['main'], split=['train'], tags=['qa', 'math'])) + +register_dataset( + DatasetMeta(ms_dataset_id='modelscope/MathR', subsets=['default', 'clean'], split=['train'], tags=['qa', 'math'])) + +register_dataset( + DatasetMeta(ms_dataset_id='modelscope/MathR-32B-Distill', subsets=['data'], split=['train'], tags=['qa', 'math'])) + + +class CoundownTaskPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + numbers = row['nums'] + target = row.pop('response', None) + query = (f'Using the numbers {numbers}, create an equation that equals {target}.\n' + 'You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.\n' + 
class HC3Preprocessor(ResponsePreprocessor):
    """Expand one HC3 row into two generative samples, one per answer source
    (a randomly sampled human answer and a randomly sampled ChatGPT answer)."""

    prompt = """Classification Task: Are the following responses from a human or from ChatGPT?
Question: {question}
Answer: {answer}
Category: Human, ChatGPT
Output:"""

    def preprocess(self, row):
        samples = []
        for source in ('Human', 'ChatGPT'):
            # assumes the project base supplies self.random_state -- seeded RNG
            answer = self.random_state.choice(row[f'{source.lower()}_answers'])
            query = self.prompt.format(question=row['query'], answer=answer)
            samples.append(super().preprocess({'query': query, 'response': source}))
        return samples
class DureaderPreprocessor(RowPreprocessor):
    """DuReader robust question-generation: 'answer[SEP]context' in text1,
    target question in text2, converted to a two-turn chat sample."""

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        prompt = """Task: Question Generation
Context: {context}
Answer: {answer}
Question:"""
        # text1 packs answer and context separated by a literal '[SEP]'.
        answer, context = row['text1'].split('[SEP]')
        user_turn = {'role': 'user', 'content': prompt.format(context=context, answer=answer)}
        assistant_turn = {'role': 'assistant', 'content': row['text2']}
        return {'messages': [user_turn, assistant_turn]}
class XlamFunctionCallingPreprocessor(RowPreprocessor):
    """Convert an xlam function-calling row into a user turn followed by one
    tool_call turn per answer, preserving the row's tool definitions."""

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        # 'answers' is a JSON-encoded list of call dicts.
        calls = json.loads(row['answers'])
        messages = [{'role': 'user', 'content': row['query']}]
        for call in calls:
            messages.append({'role': 'tool_call', 'content': json.dumps(call)})
        return {'messages': messages, 'tools': row['tools']}
def repair_conversations(s: Union[str, Any]) -> Any:
    """Parse an lmsys-chat conversation literal, first inserting the commas
    missing between adjacent dict entries. Non-strings pass through as-is."""
    if not isinstance(s, str):
        return s
    # Order matters: newline-separated variants first, then the bare '}{'.
    for broken in ('}\n {', '}\n{', '}{', '}\n {'):
        s = s.replace(broken, '},{')
    return ast.literal_eval(s)
class FunctionCallChatmlPreprocessor(MessagesPreprocessor):
    """Post-process function-calling-chatml rows: expose the function
    description as a tools list and drop the leading system message."""

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        res = super().preprocess(row)

        # NOTE(review): assumes 'function_description' is always present in the
        # converted row -- verify against MessagesPreprocessor's output schema.
        if res['function_description']:
            # Multiple tool definitions are separated by blank lines.
            res['tools'] = res['function_description'].split('\n\n')
        messages = res['messages']
        # Drop the system turn; the tool definitions now live in 'tools'.
        if messages[0]['role'] == 'system':
            messages.pop(0)
        return res
class OrpoDPOMix40kPreprocessor(MessagesPreprocessor):
    """orpo-dpo-mix-40k conversion that excludes the toxic-dpo-v0.2 subset."""

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        # Deliberately skip samples sourced from the toxic subset.
        if row['source'] == 'toxic-dpo-v0.2':
            return None
        return super().preprocess(row)
row['query'] = row['query'] + self.query_suffix + row['response'] = self.response_prefix + row['response'] + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/self-cognition', + hf_dataset_id='modelscope/self-cognition', + subsets=[ + SubsetDataset(preprocess_func=SelfCognitionPreprocessor()), + SubsetDataset( + 'qwen3', + preprocess_func=SelfCognitionPreprocessor( + query_suffix=' /no_think', response_prefix='\n\n\n\n')), + SubsetDataset( + 'empty_think', preprocess_func=SelfCognitionPreprocessor(response_prefix='\n\n\n\n')), + ], + dataset_name='self-cognition', + tags=['chat', 'self-cognition', '🔥'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='open-r1/DAPO-Math-17k-Processed', + hf_dataset_id='open-r1/DAPO-Math-17k-Processed', + subsets=['all'], + tags=['math', 'rlvr'])) diff --git a/src/twinkle/dataset/mllm.py b/src/twinkle/dataset/mllm.py new file mode 100644 index 00000000..87953705 --- /dev/null +++ b/src/twinkle/dataset/mllm.py @@ -0,0 +1,1327 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
class ShareGPT4oPreprocessor(MessagesPreprocessor):
    """ShareGPT-4o conversion: resolves each row's image path under the
    downloaded archive and drops rows whose image file is missing."""

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        row = super().preprocess(row)
        rel_path = row['images']
        if not rel_path:
            return None
        abs_path = os.path.join(self.prefix_path, rel_path)
        # Skip rows whose image was not shipped in the archive.
        if not os.path.exists(abs_path):
            return None
        row['images'] = [abs_path]
        return row

    def prepare_dataset(self, dataset):
        # Pick the mirror matching the configured hub, then download once.
        if use_hf_hub():
            url = f'{get_hf_endpoint()}/datasets/OpenGVLab/ShareGPT-4o/blob/main/images.zip'
        else:
            url = ('https://www.modelscope.cn/api/v1/datasets/AI-ModelScope/ShareGPT-4o/repo?'
                   'Revision=master&FilePath=images.zip')
        local_dir = MediaResource.download(url, 'sharegpt_4o_images')
        # Archive preserves the uploader's absolute directory layout.
        self.prefix_path = os.path.join(local_dir, 'mnt', 'petrelfs', 'wangwenhai', 'workspace_cef', '4o', 'image')
        return super().prepare_dataset(dataset)
class SA1BPairedCaptionPreprocessor(RowPreprocessor):
    """Pair the SA1B global caption with a randomly chosen Chinese VQA-style prompt."""

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        candidate_prompts = ['图片中展示了什么', '讲述一下图片中内容', '告诉我里面有什么', '图片内容是啥']
        chosen_prompt = np.random.choice(candidate_prompts)
        return {
            'messages': [
                {'role': 'user', 'content': chosen_prompt},
                {'role': 'assistant', 'content': row['global_caption']},
            ]
        }
response.get('global_caption') + query = np.random.choice(prompt) + return { + 'messages': [{ + 'role': 'user', + 'content': query, + }, { + 'role': 'assistant', + 'content': response, + }] + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='Tongyi-DataEngine/SA1B-Dense-Caption', + preprocess_func=SA1BDenseCaptionPreprocessor(columns={ + 'url': 'images', + }), + tags=['zh', 'multi-modal', 'vqa'], + huge_dataset=True, + )) + + +class COCO2014Preprocess(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + caption = row['caption'] + if '&&' in caption: + caption = caption.split('&&')[0] + row['query'] = 'please describe the image.' + row['response'] = caption + + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='modelscope/coco_2014_caption', + preprocess_func=COCO2014Preprocess(), + subsets=[ + SubsetDataset('train', 'coco_2014_caption', ['train']), + SubsetDataset('validation', 'coco_2014_caption', ['validation']), + ], + tags=['chat', 'multi-modal', 'vision', '🔥'], + )) + + +class MantisPreprocessor(MessagesPreprocessor): + + def __init__(self, *, subset: str, columns: Optional[Dict[str, str]] = None) -> None: + self.subset = subset + super().__init__(columns=columns) + + def prepare_dataset(self, dataset: HfDataset) -> HfDataset: + if not use_hf_hub(): + url = (f'https://www.modelscope.cn/api/v1/datasets/swift/Mantis-Instruct/repo?Revision=' + f'master&FilePath={self.subset}/train_images.zip') # noqa + else: + url = (f'{get_hf_endpoint()}/datasets/TIGER-Lab/Mantis-Instruct/' + f'resolve/main/{self.subset}/train_images.zip') + self.local_dir = MediaResource.download(url, f'mantis_{self.subset}') + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + images = [os.path.join(self.local_dir, p['path']) for p in row['images']] + if not all([os.path.exists(d) for d in images]): + images = [] + + if not images: + return + 
row['images'] = images + return super().preprocess(row) + + +mantis_subsets_name = [ + 'birds-to-words', 'chartqa', 'coinstruct', 'contrastive_caption', 'docvqa', 'dreamsim', 'dvqa', 'iconqa', + 'imagecode', 'llava_665k_multi', 'lrv_multi', 'multi_vqa', 'nextqa', 'nlvr2', 'spot-the-diff', 'star', + 'visual_story_telling' +] + +_mantis_subsets = [] +for subset in mantis_subsets_name: + _subset = SubsetDataset(subset=subset, split=['train'], preprocess_func=MantisPreprocessor(subset=subset)) + _mantis_subsets.append(_subset) + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/Mantis-Instruct', + subsets=_mantis_subsets, + tags=['chat', 'multi-modal', 'vision'], + )) + + +class LLaVADataPreprocessor(MessagesPreprocessor): + + def prepare_dataset(self, dataset): + self.all_folders = {} + for media_type in ['coco', 'gqa', 'ocr_vqa', 'textvqa', 'VG_100K', 'VG_100K_2']: + self.all_folders[media_type] = MediaResource.download(media_type) + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if not row['images']: + return + row = super().preprocess(row) + images = [p['path'] for p in row['images']] + new_images = [] + for image in images: + if 'coco/' in image: + image = os.path.join(self.all_folders['coco'], image.replace('coco/', '')) + elif 'gqa/' in image: + image = os.path.join(self.all_folders['gqa'], image.replace('gqa/', '')) + elif 'ocr_vqa/' in image: + image = os.path.join(self.all_folders['ocr_vqa'], image) + elif 'textvqa/' in image: + image = os.path.join(self.all_folders['textvqa'], image.replace('textvqa/', '')) + elif 'VG_100K/' in image: + image = os.path.join(self.all_folders['VG_100K'], image.replace('vg/', '')) + elif 'VG_100K_2/' in image: + image = os.path.join(self.all_folders['VG_100K_2'], image.replace('vg/', '')) + new_images.append(image) + if all(os.path.exists(image) for image in new_images): + row['images'] = new_images + else: + return {'images': None} + return row + + 
+register_dataset( + DatasetMeta( + ms_dataset_id='swift/llava-data', + hf_dataset_id='TIGER-Lab/llava-data', + subsets=['llava_instruct'], + preprocess_func=LLaVADataPreprocessor(), + tags=['sft', 'multi-modal', 'quality'], + )) + + +class PixelProsePreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + caption_prompt = [ + 'Give the description of this image.', 'Describe this picture', 'What is the proper title of this image?' + ] + vlm_caption = row['vlm_caption'] + if vlm_caption.startswith('This image displays:'): + vlm_caption = vlm_caption[len('This image displays:'):].strip() + return { + 'messages': [{ + 'role': 'user', + 'content': np.random.choice(caption_prompt) + }, { + 'role': 'assistant', + 'content': vlm_caption + }], + 'images': + row['url'] + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/pixelprose', + hf_dataset_id='tomg-group-umd/pixelprose', + preprocess_func=PixelProsePreprocessor(), + split=['train', 'cc12m', 'commonpool', 'redcaps'], + tags=['caption', 'multi-modal', 'vision'], + huge_dataset=True, + )) + + +class AIShell1Preprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + row['query'] = '语音转文本' + row['response'] = row['Text:LABEL'].replace(' ', '') + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='speech_asr/speech_asr_aishell1_trainsets', + subsets=[ + SubsetDataset('train', split=['train']), + SubsetDataset('validation', split=['validation']), + SubsetDataset('test', split=['test']), + ], + preprocess_func=AIShell1Preprocessor(columns={'Audio:FILE': 'audios'}), + tags=['chat', 'multi-modal', 'audio'], + )) + + +class EmoSchemaPreprocessor(ResponsePreprocessor): + + def prepare_dataset(self, dataset: HfDataset) -> HfDataset: + for i in range(1, 6): + if not use_hf_hub(): + url = f'https://modelscope.cn/datasets/AI-ModelScope/egoschema/resolve/master/videos_chunked_0{i}.zip' + else: + url = 
f'{get_hf_endpoint()}/datasets/lmms-lab/egoschema/resolve/main/videos_chunked_0{i}.zip' + local_dir = MediaResource.download(url, 'egoschema') + + self.local_dir = os.path.join(local_dir, 'videos') + self.mp4_set = [file[:-4] for file in os.listdir(self.local_dir) if file.endswith('mp4')] + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if row['video_idx'] not in self.mp4_set: + return None + transfer_to_option = { + '0': 'A', + '1': 'B', + '2': 'C', + '3': 'D', + '4': 'E', + } + row = { + 'query': row['query'] + '\n' + '\n'.join(row['option']), + 'response': transfer_to_option[row['response']], + 'videos': [os.path.join(self.local_dir, f"{row['video_idx']}.mp4")], + } + return super().preprocess(row) + + +class EmoSchemaClsPreprocessor(EmoSchemaPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if row['video_idx'] not in self.mp4_set: + return None + row = { + 'query': row['query'] + '\n' + '\n'.join(row['option']), + 'label': int(row['response']), + 'videos': [os.path.join(self.local_dir, f"{row['video_idx']}.mp4")], + } + return ResponsePreprocessor.preprocess(self, row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/egoschema', + hf_dataset_id='lmms-lab/egoschema', + subsets=[ + SubsetDataset('default', 'Subset', preprocess_func=EmoSchemaPreprocessor()), + SubsetDataset('cls', 'Subset', preprocess_func=EmoSchemaClsPreprocessor()) + ], + split=['test'], + tags=['chat', 'multi-modal', 'video'], + )) + + +def _generate_url_list(_url, _range): + lst = [] + for i in range(1, (_range + 1)): + lst.append(_url.replace('{}', str(i))) + return lst + + +class LLaVAVideo178KPreprocessor(MessagesPreprocessor): + + def __init__(self, *, subset: str, columns: Optional[Dict[str, str]] = None) -> None: + self.subset = subset + super().__init__(columns=columns) + + url_prefix = 
'https://www.modelscope.cn/datasets/lmms-lab/LLaVA-Video-178K/resolve/master/' + if use_hf_hub(): + url_prefix = f'{get_hf_endpoint()}/datasets/lmms-lab/LLaVA-Video-178K/resolve/main/' + + video_resources = { + '0_30_s_academic_v0_1': + _generate_url_list( + url_prefix + '0_30_s_academic_v0_1/0_30_s_academic_v0_1_videos_{}.tar.gz', + 8, + ), + '0_30_s_youtube_v0_1': + _generate_url_list( + url_prefix + '0_30_s_youtube_v0_1/0_30_s_youtube_v0_1_videos_{}.tar.gz', + 19, + ), + '1_2_m_academic_v0_1': + _generate_url_list( + url_prefix + '1_2_m_academic_v0_1/1_2_m_academic_v0_1_videos_{}.tar.gz', + 14, + ), + '1_2_m_youtube_v0_1': + _generate_url_list( + url_prefix + '1_2_m_youtube_v0_1/1_2_m_youtube_v0_1_videos_{}.tar.gz', + 50, + ), + '2_3_m_academic_v0_1': + _generate_url_list( + url_prefix + '2_3_m_academic_v0_1/2_3_m_academic_v0_1_videos_{}.tar.gz', + 18, + ), + '2_3_m_youtube_v0_1': + _generate_url_list( + url_prefix + '2_3_m_youtube_v0_1/2_3_m_youtube_v0_1_videos_{}.tar.gz', + 98, + ), + '30_60_s_academic_v0_1': + _generate_url_list( + url_prefix + '30_60_s_academic_v0_1/30_60_s_academic_v0_1_videos_{}.tar.gz', + 10, + ), + '30_60_s_youtube_v0_1': + _generate_url_list( + url_prefix + '30_60_s_youtube_v0_1/30_60_s_youtube_v0_1_videos_{}.tar.gz', + 13, + ), + } + + def prepare_dataset(self, dataset: HfDataset) -> HfDataset: + urls = self.video_resources[self.subset] + self.local_dir = MediaResource.download(urls, f'llava_video_178k_{self.subset}', file_type='sharded') + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + file_path = os.path.join(self.local_dir, f"{row['videos']}") + if not os.path.exists(file_path): + return None + return super().preprocess({'messages': row['messages'], 'videos': file_path}) + + +llava_video_subsets = [] +for subset in [ + '0_30_s_academic_v0_1', + '0_30_s_youtube_v0_1', + '1_2_m_academic_v0_1', + '1_2_m_youtube_v0_1', + '2_3_m_academic_v0_1', + '2_3_m_youtube_v0_1', 
+ '30_60_s_academic_v0_1', + '30_60_s_youtube_v0_1', +]: + subset = SubsetDataset( + subset=subset, + split=['caption', 'open_ended', 'multi_choice'], + preprocess_func=LLaVAVideo178KPreprocessor(subset=subset), + ) + llava_video_subsets.append(subset) + +register_dataset( + DatasetMeta( + hf_dataset_id='lmms-lab/LLaVA-Video-178K', subsets=llava_video_subsets, tags=['chat', 'multi-modal', 'video'])) + + +class MovieChat1KPreprocessor(ResponsePreprocessor): + + def prepare_dataset(self, dataset: HfDataset) -> HfDataset: + mp4_set = [f'{i}.mp4' for i in range(1, 10)] + \ + [f'{i}.mp4' for i in range(201, 240)] + \ + [f'AWA-{i}.mp4' for i in range(1, 10)] + \ + [f'AWB-{i}.mp4' for i in range(1, 16)] + \ + [f'AWC-{i}.mp4' for i in range(1, 11)] + \ + [f'AWD-{i}.mp4' for i in range(1, 8)] + \ + [f'AWE-{i}.mp4' for i in range(1, 7)] + \ + [f'AWG-{i}.mp4' for i in range(1, 12)] + \ + [f'AWH-{i}.mp4' for i in range(1, 8)] + \ + [f'BWA-{i}.mp4' for i in range(1, 7)] + \ + [f'BWB-{i}.mp4' for i in range(1, 7)] + \ + [f'BWD-{i}.mp4' for i in range(1, 6)] + \ + [f'BWE-{i}.mp4' for i in range(1, 6)] + \ + [f'BWG-{i}.mp4' for i in range(1, 6)] + \ + [f'BWH-{i}.mp4' for i in range(1, 6)] + \ + [f'TFS-{i}.mp4' for i in range(1, 13)] + \ + [f'UWA-{i}.mp4' for i in range(1, 5)] + ['UWA-6.mp4'] + for file in mp4_set: + if not use_hf_hub(): + url = f'https://modelscope.cn/datasets/AI-ModelScope/MovieChat-1K-test/resolve/master/videos/{file}' + else: + url = f'{get_hf_endpoint()}/datasets/Enxin/MovieChat-1K-test/resolve/main/videos/{file}' + self.local_dir = MediaResource.download(url, 'moviechat_1k_test', file_type='file') + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + file_path = os.path.join(self.local_dir, f"{row['info']['video_path']}") + if not os.path.exists(file_path): + return None + return super().preprocess({ + 'query': row['global'][0]['question'], + 'response': row['global'][0]['answer'], + 'videos': 
file_path, + }) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/MovieChat-1K-test', + hf_dataset_id='Enxin/MovieChat-1K-test', + preprocess_func=MovieChat1KPreprocessor(), + split=['train'], + tags=['chat', 'multi-modal', 'video'])) + + +class VideoChatGPTPreprocessor(ResponsePreprocessor): + + def prepare_dataset(self, dataset: HfDataset) -> HfDataset: + if not use_hf_hub(): + url = 'https://modelscope.cn/datasets/swift/VideoChatGPT/resolve/master/videos.zip' + else: + url = f'{get_hf_endpoint()}/datasets/lmms-lab/VideoChatGPT/resolve/main/videos.zip' + local_dir = MediaResource.download(url, 'video_chatgpt') + self.local_dir = os.path.join(local_dir, 'Test_Videos') + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + # only `.mp4` + mp4_set = [file[:-4] for file in os.listdir(self.local_dir) if file.endswith('mp4')] + if row['video_name'] not in mp4_set: + return + row['videos'] = os.path.join(self.local_dir, f"{row['video_name']}.mp4") + for key in ['query', 'question_1', 'question_2']: + query = row.get(key) + if query is None or query == 'None': + continue + row['query'] = query + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/VideoChatGPT', + hf_dataset_id='lmms-lab/VideoChatGPT', + subsets=['Generic', 'Temporal', 'Consistency'], + preprocess_func=VideoChatGPTPreprocessor(), + split=['test'], + tags=['chat', 'multi-modal', 'video', '🔥'], + )) + + +def preprocess_mind2web(dataset, **kwargs): + + def preprocess_row(row: Dict[str, Any]) -> Dict[str, Any]: + raw_html = row['cleaned_html'] + screenshot = row['screenshot'] + row['screenshot'] = MediaResource.safe_save(screenshot, row['action_uid'] + '.jpg', 'mind2web') + action = row['target_action_reprs'] + actions = action.split('->') + row['query'] = f'The snapshot of screen:\nThe html source code:{raw_html}\n' + action = actions[-1] + where = actions[0] if len(actions) > 1 else '' 
+ what = '' + if ':' in action: + action, what = action[:action.find(':')], action[action.find(':') + 1:] + row['response'] = f'Action: {action.strip()}\nAction Input: {where.strip()}{"," + what.strip()}' + return row + + conversations = [] + tools = [{ + 'function': { + 'name': 'CLICK', + 'desc': 'Choose and click an element in the web page', + 'parameter': [{ + 'element': 'string, the element in the web page to click' + }] + } + }, { + 'function': { + 'name': + 'TYPE', + 'desc': + 'Input some text into a web element like or ', + 'parameter': [{ + 'element': 'string, the element in the web page to input to', + 'content': 'string, what content to input into the textbox element' + }] + } + }, { + 'function': { + 'name': + 'SELECT', + 'desc': + 'Select an element from a combobox', + 'parameter': [{ + 'element': 'string, the combobox or dropdown in the web page on which the select happens', + 'content': 'string, which choices to choose' + }] + } + }] + + def history_to_messages(history): + messages = [] + for h in history: + messages.append({'role': 'user', 'content': h[0]}) + messages.append({'role': 'assistant', 'content': h[1]}) + return messages + + if isinstance(dataset, HfIterableDataset): + + def generate_example(dataset): + history = [] + images = [] + for row in dataset: + target_action_index = row['target_action_index'] + row = preprocess_row(row) + query = row['query'] + if target_action_index == '0': + if history: + yield {'messages': history_to_messages(history), 'images': images, 'tools': tools} + images = [] + history = [] + query = query + '\n' + row['confirmed_task'] + history.append([query, row['response']]) + images.append(row['screenshot']) + + if history: + yield {'messages': history_to_messages(history), 'images': images, 'tools': tools} + + return HfIterableDataset.from_generator(generate_example, gen_kwargs={'dataset': dataset}) + + history = [] + images = [] + for row in tqdm(dataset): + target_action_index = row['target_action_index'] + row = 
preprocess_row(row) + query = row['query'] + if target_action_index == '0': + if history: + conversations.append({'messages': history_to_messages(history), 'images': images, 'tools': tools}) + images = [] + history = [] + query = query + '\n' + row['confirmed_task'] + history.append([query, row['response']]) + images.append(row['screenshot']) + + if history: + conversations.append({'messages': history_to_messages(history), 'images': images, 'tools': tools}) + + return HfDataset.from_list(conversations) + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/Multimodal-Mind2Web', + hf_dataset_id='osunlp/Multimodal-Mind2Web', + preprocess_func=preprocess_mind2web, + tags=['agent', 'multi-modal'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/M3IT', + subsets=[ + 'coco', 'vqa-v2', 'shapes', 'shapes-rephrased', 'coco-goi-rephrased', 'snli-ve', 'snli-ve-rephrased', + 'okvqa', 'a-okvqa', 'viquae', 'textcap', 'docvqa', 'science-qa', 'imagenet', 'imagenet-open-ended', + 'imagenet-rephrased', 'coco-goi', 'clevr', 'clevr-rephrased', 'nlvr', 'coco-itm', 'coco-itm-rephrased', + 'vsr', 'vsr-rephrased', 'mocheg', 'mocheg-rephrased', 'coco-text', 'fm-iqa', 'activitynet-qa', 'msrvtt', + 'ss', 'coco-cn', 'refcoco', 'refcoco-rephrased', 'multi30k', 'image-paragraph-captioning', 'visual-dialog', + 'visual-dialog-rephrased', 'iqa', 'vcr', 'visual-mrc', 'ivqa', 'msrvtt-qa', 'msvd-qa', 'gqa', 'text-vqa', + 'ocr-vqa', 'st-vqa', 'flickr8k-cn' + ], + preprocess_func=ResponsePreprocessor(columns={ + 'instruction': 'system', + 'inputs': 'query', + 'image_base64_str': 'images', + 'outputs': 'response' + }), + split=['train'], + huge_dataset=True, + tags=['chat', 'multi-modal', 'vision'])) + + +class ShareGPT4VPreprocessor(MessagesPreprocessor): + + def prepare_dataset(self, dataset): + split = ['ShareGPT4V', 'ShareGPT4V-PT'] if dataset.config_name is None else dataset.config_name + IMAGE_DATASET_REQUIREMENTS = { + 'ShareGPT4V': ['coco', 'sam', 'llava', 'wikiart', 
'share_textvqa', 'web-celebrity', 'web-landmark'], + 'ShareGPT4V-PT': ['coco', 'sam', 'llava'] + } + + if isinstance(split, str): + split = [split] + self.all_folders = {} + for sp in split: + for media_type in IMAGE_DATASET_REQUIREMENTS[sp]: + self.all_folders[media_type] = MediaResource.download(media_type) + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + image = row['image'] + row.update(super().preprocess(row)) + if 'coco/' in image: + image = os.path.join(self.all_folders['coco'], image.replace('coco/', '')) + elif 'sam/' in image: + image = os.path.join(self.all_folders['sam'], image.replace('sam/images/', '')) + elif 'llava/' in image: + image = os.path.join(self.all_folders['llava'], image.replace('llava/llava_pretrain/images/', '')) + elif 'wikiart/' in image: + image = os.path.join(self.all_folders['wikiart'], image.replace('wikiart/images/', 'data/wikiart/images/')) + elif 'share_textvqa/' in image: + image = os.path.join(self.all_folders['share_textvqa'], + image.replace('share_textvqa/images/', 'data/share_textvqa/images/')) + elif 'web-celebrity/' in image: + image = os.path.join(self.all_folders['web-celebrity'], + image.replace('web-celebrity/images/', 'data/web-celebrity/images/')) + elif 'web-landmark/' in image: + image = os.path.join(self.all_folders['web-landmark'], + image.replace('web-landmark/images/', 'data/web-landmark/images/')) + if os.path.exists(image): + row['images'] = image + else: + return + return row + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/ShareGPT4V', + subsets=['ShareGPT4V', 'ShareGPT4V-PT'], + preprocess_func=ShareGPT4VPreprocessor(), + huge_dataset=True, + tags=['chat', 'multi-modal', 'vision'])) + + +class TextCapsPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + row['query'] = 'What is the caption of this image?' 
+ if not os.path.exists(row['images']['path']): + return None + return super().preprocess(row) + + +class TextCapsEmbPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if not os.path.exists(row['images']['path']): + return None + return { + 'messages': [{ + 'role': 'user', + 'content': '' + }], + 'positive_messages': [[{ + 'role': 'user', + 'content': row['response'][0] + }]], + 'images': row['images'] + } + + +class TextCapsReRankPreprocessor(RowPreprocessor): + + def __init__(self, + *, + columns: Optional[Dict[str, str]] = None, + negatives_per_sample: Optional[int] = None, + **kwargs): + super().__init__(columns=columns, **kwargs) + self._responses_pool: Optional[List[str]] = None + # Keep default aligned with MAX_NEGATIVE_SAMPLES used in collator if not provided + self.negatives_per_sample: int = int(os.environ.get( + 'MAX_NEGATIVE_SAMPLES', 7)) if negatives_per_sample is None else negatives_per_sample + + def prepare_dataset(self, dataset): + # Build a pool of responses from the dataset for negative sampling + try: + # Access full column; works for map-style datasets + responses = dataset['response'] if 'response' in dataset.features else None + except Exception: + responses = None + + pool: List[str] = [] + if responses is not None: + for resp in responses: + if isinstance(resp, (list, tuple)): + for s in resp: + if isinstance(s, str): + pool.append(s) + elif isinstance(resp, str): + pool.append(resp) + self._responses_pool = pool if pool else None + return dataset + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if not os.path.exists(row['images']['path']): + return None + positive = row['response'][0] + negatives: List[str] = [] + if self._responses_pool: + candidates = [s for s in self._responses_pool if isinstance(s, str) and s not in row['response']] + if candidates: + k = min(self.negatives_per_sample, len(candidates)) + # Use numpy RandomState from base class for 
deterministic sampling + idxs = self.random_state.choice(len(candidates), size=k, replace=False).tolist() + negatives = [candidates[i] for i in idxs] + return { + 'messages': [{ + 'role': 'user', + 'content': '' + }], + 'positive_messages': [[{ + 'role': 'user', + 'content': positive + }]], + 'negative_messages': [[{ + 'role': 'user', + 'content': n + }] for n in negatives], + 'images': row['images'] + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/TextCaps', + hf_dataset_id='HuggingFaceM4/TextCaps', + subsets=[ + SubsetDataset( + name='default', + preprocess_func=TextCapsPreprocessor(columns={'reference_strs': 'response'}), + split=['train', 'validation'], + ), + SubsetDataset( + name='emb', + preprocess_func=TextCapsEmbPreprocessor(columns={'reference_strs': 'response'}), + split=['train', 'validation'], + ), + SubsetDataset( + name='rerank', + preprocess_func=TextCapsReRankPreprocessor(columns={'reference_strs': 'response'}), + split=['train', 'validation'], + ), + ], + huge_dataset=True, + tags=['multi-modal', 'en', 'caption', 'quality'])) + + +class RefCOCOPreprocessor(ResponsePreprocessor, GroundingMixin): + task_type = 'caption' + + def __init__(self, task_type, **kwargs): + self.task_type = task_type + super().__init__(**kwargs) + + def prepare_dataset(self, dataset): + self.cache_dir = MediaResource.download( + 'https://www.modelscope.cn/api/v1/datasets/we_dont_produce_water/' + 'coco_res/repo?Revision=master&FilePath=coco_2014.zip', 'coco2014') + return dataset + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + caption = row['captions'][0] + bbox = row['bbox'] + image_path = os.path.join(self.cache_dir, row['image_path'].replace('coco/train2014', 'train2014')) + if not os.path.exists(image_path): + return + + for i in range(len(bbox)): + bbox[i] = round(float(bbox[i])) + res = {} + + objects = { + 'ref': [caption], + 'bbox': [bbox], + } + res['query'], res['response'] = self.construct_grounding_prompt() + 
res['images'] = [image_path] + res['objects'] = objects + return super().preprocess(res) + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/refcoco', + hf_dataset_id='jxu124/refcoco', + subsets=[ + SubsetDataset( + name='caption', + preprocess_func=RefCOCOPreprocessor('caption'), + ), + SubsetDataset( + name='grounding', + preprocess_func=RefCOCOPreprocessor('grounding'), + ) + ], + split=['train', 'validation'], + tags=['multi-modal', 'en', 'grounding'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/refcocog', + hf_dataset_id='jxu124/refcocog', + subsets=[ + SubsetDataset( + name='caption', + preprocess_func=RefCOCOPreprocessor('caption'), + ), + SubsetDataset( + name='grounding', + preprocess_func=RefCOCOPreprocessor('grounding'), + ) + ], + split=['train', 'validation'], + tags=['multi-modal', 'en', 'grounding'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/lnqa', + hf_dataset_id='vikhyatk/lnqa', + preprocess_func=MessagesPreprocessor(user_role='question', assistant_role='answer'), + split=['train', 'validation'], + huge_dataset=True, + tags=['multi-modal', 'en', 'ocr-vqa', 'quality'])) + + +class LLaVAInstructPreprocessor(MessagesPreprocessor): + + def prepare_dataset(self, dataset): + self.all_folders = {} + for media_type in ['coco', 'gqa', 'ocr_vqa', 'textvqa', 'VG_100K', 'VG_100K_2']: + self.all_folders[media_type] = MediaResource.download(media_type) + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + image = row['images'] + if 'coco/' in image: + image = os.path.join(self.all_folders['coco'], image.replace('coco/', '')) + elif 'gqa/' in image: + image = os.path.join(self.all_folders['gqa'], image.replace('gqa/', '')) + elif 'ocr_vqa/' in image: + image = os.path.join(self.all_folders['ocr_vqa'], image) + elif 'textvqa/' in image: + image = os.path.join(self.all_folders['textvqa'], image.replace('textvqa/', '')) + elif 'VG_100K/' in image: + image = 
os.path.join(self.all_folders['VG_100K'], image.replace('vg/', '')) + elif 'VG_100K_2/' in image: + image = os.path.join(self.all_folders['VG_100K_2'], image.replace('vg/', '')) + if os.path.exists(image): + row['images'] = image + else: + return + + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/LLaVA-Instruct-150K', + ms_revision='d5db3806e395c60496630a206c336932e85a2d00', + preprocess_func=LLaVAInstructPreprocessor(), + split=['train'], + tags=['chat', 'multi-modal', 'vision'])) + + +class LLaVAPretrainPreprocessor(MessagesPreprocessor): + + def prepare_dataset(self, dataset): + if not use_hf_hub(): + url = ('https://www.modelscope.cn/api/v1/datasets/AI-ModelScope/LLaVA-Pretrain/repo?' + 'Revision=master&FilePath=images.zip') + else: + url = f'{get_hf_endpoint()}/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip' + self.media_dir = MediaResource.download( + url, + # noqa + 'llava_pretrain') + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + row.update(super().preprocess(row)) + if row['image']: + file_path = os.path.join(self.media_dir, row['image']) + if os.path.exists(file_path): + return {'images': file_path} + else: + return + else: + return + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/LLaVA-Pretrain', + ms_revision='e3a3f0bfaad05e90e46745152a32bf944e0f4a63', + hf_dataset_id='liuhaotian/LLaVA-Pretrain', + preprocess_func=LLaVAPretrainPreprocessor(), + huge_dataset=True, + tags=['chat', 'multi-modal', 'quality'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/MideficsDataset', + hf_dataset_id='WinterSchool/MideficsDataset', + preprocess_func=MessagesPreprocessor(inner_key='data', user_role='question', assistant_role='answer'), + tags=['medical', 'en', 'vqa'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/OK-VQA_train', + hf_dataset_id='Multimodal-Fatima/OK-VQA_train', + 
preprocess_func=ResponsePreprocessor(), + tags=['multi-modal', 'en', 'vqa', 'quality'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/A-OKVQA', + hf_dataset_id='HuggingFaceM4/A-OKVQA', + split=['train', 'validation'], + preprocess_func=ResponsePreprocessor(columns={'rationales': 'response'}), + tags=['multi-modal', 'en', 'vqa', 'quality'])) + + +class OcrvqaPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + idx = np.random.choice(range(len(row['questions']))) + query = row['questions'][idx] + response = row['answers'][idx] + return { + 'messages': [{ + 'role': 'user', + 'content': query + }, { + 'role': 'assistant', + 'content': response + }], + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/OCR-VQA', + hf_dataset_id='howard-hou/OCR-VQA', + split=['train', 'validation'], + preprocess_func=OcrvqaPreprocessor(), + tags=['multi-modal', 'en', 'ocr-vqa'])) + + +class ScienceQAPreprocessor(RowPreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + query = row['question'] + response = row['choices'][row['answer']] + solution = row['solution'] + response = f'{solution}\nSo the final answer is: {response}' + return {'messages': [{'role': 'user', 'content': query}, {'role': 'assistant', 'content': response}]} + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/ScienceQA', + hf_dataset_id='derek-thomas/ScienceQA', + split=['train', 'validation'], + preprocess_func=ScienceQAPreprocessor(), + tags=['multi-modal', 'science', 'vqa', 'quality'])) + + +class GritPreprocessor(RowPreprocessor, GroundingMixin): + + def __init__(self, task_type, **kwargs): + self.task_type = task_type + super().__init__(**kwargs) + + @staticmethod + def has_overlap(start_ends): + for i in range(1, len(start_ends)): + if start_ends[i][0] < start_ends[i - 1][1]: + return True + return False + + @staticmethod + def replace_intervals_with_tags(response, start_ends): + result = [] + last_end = 0 + 
for start, end in start_ends: + result.append(response[int(last_end):int(start)]) + result.append('') + last_end = end + result.append(response[int(last_end):]) + return ''.join(result) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + images = row['images'] + caption = row['caption'] + ref_exps = row['ref_exps'] + objects = {'ref': [], 'bbox': [], 'bbox_type': 'norm1'} + start_end_pairs = [] + for ref_exp in ref_exps: + start = ref_exp[0] + end = ref_exp[1] + # conf = ref_exp[6] TODO filter low confidence rows? + start_end_pairs.append(ref_exp[0:2]) + + object_part = caption[int(start):int(end)] + objects['ref'].append(object_part) + objects['bbox'].append(ref_exp[2:6]) + + start_end_pairs.sort(key=lambda x: (x[0], x[1])) + if self.has_overlap(start_end_pairs) or not ref_exps: + return + + if self.task_type in ('grounding', 'caption'): + query, response = self.construct_grounding_prompt() + else: + query = 'what is the proper caption of this image?' + response = caption + return { + 'messages': [{ + 'role': 'user', + 'content': query + }, { + 'role': 'assistant', + 'content': response + }], + 'images': images, + 'objects': objects + } + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/GRIT', + hf_dataset_id='zzliang/GRIT', + subsets=[ + SubsetDataset( + name='caption', + preprocess_func=GritPreprocessor('caption', columns={'url': 'images'}), + ), + SubsetDataset( + name='grounding', + preprocess_func=GritPreprocessor('grounding', columns={'url': 'images'}), + ), + SubsetDataset( + name='vqa', + preprocess_func=GritPreprocessor('vqa', columns={'url': 'images'}), + ) + ], + huge_dataset=True, + tags=['multi-modal', 'en', 'caption-grounding', 'vqa', 'quality'])) + + +class GQAPreprocessor(RowPreprocessor): + + def prepare_dataset(self, dataset): + self.local_cache = MediaResource.download('gqa') + return super().prepare_dataset(dataset) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if 
os.path.join(self.local_cache, 'images', row['imageId'] + '.jpg'): + return { + 'messages': [{ + 'role': 'user', + 'content': row['question'] + }, { + 'role': 'assistant', + 'content': row['fullAnswer'] + }], + 'images': + os.path.join(self.local_cache, 'images', row['imageId'] + '.jpg'), + } + else: + return + + +register_dataset( + DatasetMeta( + hf_dataset_id='lmms-lab/GQA', + split=['train_all_instructions'], + preprocess_func=GQAPreprocessor(), + huge_dataset=True, + tags=['multi-modal', 'en', 'vqa', 'quality'])) + + +class CocoPreprocessor(ResponsePreprocessor): + category = [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', + 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', + 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', + 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', + 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' + ] + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + row['query'] = 'Task: Object Detection' + objects = row['objects'] + objects['ref'] = [self.category[c] for c in objects['category']] + row['response'] = '\n'.join([''] * len(objects['ref'])) + return super().preprocess(row) + + +register_dataset( + DatasetMeta( + ms_dataset_id='AI-ModelScope/coco', + hf_dataset_id='detection-datasets/coco', + preprocess_func=CocoPreprocessor(), + huge_dataset=True, + tags=['multi-modal', 'en', 'vqa', 'quality'])) 
class LLaVAMixSFTPreprocessor(RowPreprocessor):
    """Flatten LLaVA-mix multi-part messages into plain-text turns.

    Each message's ``content`` is a list of parts; text parts are
    concatenated and image parts are replaced by an ``<image>`` placeholder.
    """

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        messages = row['messages']
        rounds = []
        for msg in messages:
            role = msg['role']
            content = msg['content']
            text = ''
            for part in content:
                if part['type'] == 'text':
                    text += part['text']
                elif part['type'] == 'image':
                    # NOTE(review): restored the '<image>' placeholder — the
                    # incoming text appended an empty string, which looks like
                    # a stripped tag; confirm against the template format.
                    text += '<image>'

            rounds.append({'role': role, 'content': text})

        return {'messages': rounds}


register_dataset(
    DatasetMeta(
        ms_dataset_id='swift/llava-instruct-mix-vsft',
        hf_dataset_id='HuggingFaceH4/llava-instruct-mix-vsft',
        split=['test'],
        preprocess_func=LLaVAMixSFTPreprocessor(),
        tags=['multi-modal', 'en', 'vqa', 'quality']))


class LatexocrPreprocessor(ResponsePreprocessor):
    """Attach a fixed LaTeX-OCR instruction to every row."""

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        row['query'] = 'Using LaTeX to perform OCR on the image.'
        return super().preprocess(row)


register_dataset(
    DatasetMeta(
        ms_dataset_id='AI-ModelScope/LaTeX_OCR',
        hf_dataset_id='linxy/LaTeX_OCR',
        subsets=['default', 'human_handwrite', 'human_handwrite_print', 'synthetic_handwrite', 'small'],
        preprocess_func=LatexocrPreprocessor(),
        split=['train', 'validation', 'test'],
        tags=['chat', 'ocr', 'multi-modal', 'vision'],
    ))


class CapchaImagesPreprocessor(ResponsePreprocessor):
    """Attach a fixed captcha-recognition instruction to every row."""

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        row['query'] = 'recognize the content.'
        return super().preprocess(row)


register_dataset(
    DatasetMeta(
        ms_dataset_id='AI-ModelScope/captcha-images',
        split=['train', 'validation'],
        preprocess_func=CapchaImagesPreprocessor(columns={'solution': 'response'}),
        tags=['chat', 'multi-modal', 'vision']))


class ClevrPreprocessor(ResponsePreprocessor):
    """Append the GRPO-style think/answer format instruction to the query."""

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        query = row.get('query', '')
        # NOTE(review): restored the '<think>'/'<answer>' tags — the incoming
        # text read "in and ... in tags.", i.e. the tags were stripped in
        # transit; confirm against the GRPO prompt format.
        query = (f'{query} Output the thinking process in <think> </think> and '
                 'final answer (number) in <answer> </answer> tags.')
        row.update({'query': query})
        return super().preprocess(row)


register_dataset(
    DatasetMeta(
        ms_dataset_id='AI-ModelScope/clevr_cogen_a_train',
        hf_dataset_id='leonardPKU/clevr_cogen_a_train',
        preprocess_func=ClevrPreprocessor(),
        tags=['qa', 'math', 'vision', 'grpo']))


class Voc2007MultilabelPreprocessor(ResponsePreprocessor):
    """Turn VOC2007 one-hot label vectors into multi-label class indices."""

    # The 20 PASCAL VOC object classes, indexed by position in 'npy'.
    CLASS_NAME = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                  'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        row['query'] = f'多标签分类,类别包括:{list(self.CLASS_NAME)}'
        # 'npy' is a one-hot multi-label vector; keep the indices that are set.
        row['label'] = [i for i, x in enumerate(row['npy']) if x == 1]
        return super().preprocess(row)


register_dataset(
    DatasetMeta(
        ms_dataset_id='clip-benchmark/wds_voc2007_multilabel',
        hf_dataset_id='clip-benchmark/wds_voc2007_multilabel',
        preprocess_func=Voc2007MultilabelPreprocessor(columns={'webp': 'images'}),
        tags=['multilabel', 'multi-modal'],
    ))
from .mse import MSELoss
from .contrastive_loss import ContrastiveLoss
from .online_contrastive_loss import OnlineContrastiveLoss
from .infonce import InfoNCELoss
from .cross_entropy import CrossEntropyLoss
from .chunked_cross_entropy import ChunkedCrossEntropyLoss, ChunkedCrossEntropyLossFunc
from .cosine_similarity import CosineSimilarityLoss
from .generative_reranker import GenerativeRerankerLoss
from .reranker import RerankerLoss
from .listwise_reranker import ListwiseRerankerLoss
from .listwise_generative_reranker import ListwiseGenerativeRerankerLoss

# Registry resolving a short loss name to its torch implementation class.
# Values are classes (not instances): callers construct them with whatever
# per-loss arguments they need (e.g. tokenizer, temperature).
torch_loss_mapping = {
    'mse': MSELoss,
    'contrastive': ContrastiveLoss,
    'online_contrastive': OnlineContrastiveLoss,
    'infonce': InfoNCELoss,
    'cross_entropy': CrossEntropyLoss,
    'chunked_cross_entropy': ChunkedCrossEntropyLoss,
    'cosine_similarity': CosineSimilarityLoss,
    'generative_reranker': GenerativeRerankerLoss,
    'reranker': RerankerLoss,
    'listwise_reranker': ListwiseRerankerLoss,
    'listwise_generative_reranker': ListwiseGenerativeRerankerLoss,
}


class Loss:
    """Minimal callable interface shared by all loss implementations.

    Concrete losses implement ``__call__`` and return a scalar tensor; the
    base implementation is intentionally a no-op placeholder (returns None).
    """

    def __call__(self, *args, **kwargs):
        ...
class ChunkedCrossEntropyLossFunc(torch.autograd.Function):
    """Memory-efficient per-token cross entropy computed chunk by chunk.

    Forward evaluates ``CrossEntropyLoss(reduction='none')`` on row chunks of
    size ``chunk_size`` and concatenates the results, so only one chunk of
    intermediate buffers is alive at a time. Backward recomputes each chunk's
    loss under ``enable_grad`` and derives its gradient on the fly.

    NOTE(review): backward() writes the gradients in place into the saved
    ``logits`` tensor and returns it as the gradient, avoiding a second
    logits-sized allocation — the caller must not reuse ``logits`` afterwards.
    """

    @staticmethod
    def forward(ctx, logits, labels, chunk_size):
        ctx.save_for_backward(logits, labels)
        ctx.chunk_size = chunk_size

        ce = torch.nn.CrossEntropyLoss(reduction='none')
        chunk_losses = []
        for start in range(0, logits.shape[0], chunk_size):
            stop = start + chunk_size
            chunk_losses.append(ce(logits[start:stop], labels[start:stop]))
        return torch.cat(chunk_losses)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any):
        logits, labels = ctx.saved_tensors
        chunk_size = ctx.chunk_size

        ce = torch.nn.CrossEntropyLoss(reduction='none')
        for start in range(0, logits.shape[0], chunk_size):
            stop = start + chunk_size
            logits_chunk = logits[start:stop].detach().requires_grad_(True)
            with torch.enable_grad():
                losses_chunk = ce(logits_chunk, labels[start:stop])
                weighted = (losses_chunk * grad_outputs[0][start:stop]).sum()
            grad_chunk = torch.autograd.grad(weighted, logits_chunk, retain_graph=False)[0]
            # Reuse the saved logits storage as the gradient buffer.
            logits[start:stop] = grad_chunk

        # Gradients flow only to `logits`; labels and chunk_size get none.
        return logits, None, None
class ChunkedCrossEntropyLoss(Loss):
    """Loss wrapper that delegates to :class:`ChunkedCrossEntropyLossFunc`.

    Args:
        chunk_size: Number of rows processed per chunk.
    """

    def __init__(self, chunk_size):
        self.chunk_size = chunk_size

    def __call__(self, logits, labels, **kwargs):
        return ChunkedCrossEntropyLossFunc.apply(logits, labels, self.chunk_size)


# Code borrowed from sentence_transformers
class SiameseDistanceMetric(Enum):
    """The metric for the contrastive loss."""

    # FIX: the original bodies referenced an undefined name ``F`` (this file
    # never imports torch.nn.functional), so every metric raised NameError at
    # call time. Use the fully qualified path instead.
    EUCLIDEAN = lambda x, y: torch.nn.functional.pairwise_distance(x, y, p=2)  # noqa
    MANHATTAN = lambda x, y: torch.nn.functional.pairwise_distance(x, y, p=1)  # noqa
    COSINE_DISTANCE = lambda x, y: 1 - torch.nn.functional.cosine_similarity(x, y)  # noqa


class ContrastiveLoss(Loss):
    """Siamese contrastive loss on cosine distance (margin 0.5).

    Positives (label 1) are pulled together (squared distance); negatives
    (label 0) are pushed apart up to the margin.
    """

    def __call__(self, sentence1, sentence2, labels, **kwargs):
        distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
        distances = distance_metric(sentence1, sentence2)
        margin = 0.5
        labels = labels.to(sentence1.dtype)
        losses = 0.5 * (labels * distances.pow(2) + (1 - labels) * torch.nn.functional.relu(margin - distances).pow(2))
        return losses.mean()


class CosineSimilarityLoss(Loss):
    """MSE between the pairwise cosine similarity and the target score."""

    def __call__(self, sentence1, sentence2, labels, **kwargs):
        cos_score_transformation = torch.nn.Identity()
        # FIX: ``torch.MSELoss`` does not exist — the class lives in torch.nn.
        loss_fct = torch.nn.MSELoss()
        output = cos_score_transformation(torch.cosine_similarity(sentence1, sentence2))
        return loss_fct(output, labels.to(output.dtype).view(-1))
class CrossEntropyLoss(Loss):
    """Standard cross entropy over logits and integer class labels."""

    def __call__(self, logits, labels, **kwargs):
        criterion = torch.nn.CrossEntropyLoss()
        return criterion(logits, labels)


class GenerativeRerankerLoss(Loss):
    """Binary relevance loss for generative rerankers.

    Relevance is read from the logits of two special tokens (default
    ``yes``/``no``) at each sample's last valid (non-padding) position, and
    trained as a two-way classification against the 0/1 ``labels``.
    """

    def __init__(self, tokenizer, positive_token='yes', negative_token='no'):
        self.tokenizer = tokenizer
        self.positive_token = positive_token
        self.negative_token = negative_token

    def __call__(self, logits, labels, last_valid_indices, **kwargs):
        """Compute the loss.

        Args:
            logits: Model logits ``[batch, seq_len, vocab]``.
            labels: Binary labels (0/1) for irrelevant/relevant pairs.
            last_valid_indices: Index of the last valid token per sample.

        Returns:
            torch.Tensor: Cross entropy loss for yes/no classification.

        Raises:
            ValueError: If the configured tokens cannot be mapped to ids.
        """
        try:
            pos_id = self.tokenizer.convert_tokens_to_ids(self.positive_token)
            neg_id = self.tokenizer.convert_tokens_to_ids(self.negative_token)
        except Exception as e:
            raise ValueError(f"Failed to convert tokens '{self.positive_token}'/'{self.negative_token}' to IDs. "
                             f'Please check if these tokens exist in the tokenizer vocabulary. Error: {e}')

        # Pick out the logits at each sample's last non-padding position.
        rows = torch.arange(logits.shape[0], device=logits.device)
        final_logits = logits[rows, last_valid_indices, :]

        # [batch, 2] logits ordered [negative, positive] so that label 1
        # (relevant) selects the positive-token column.
        pair_logits = torch.stack([final_logits[:, neg_id], final_logits[:, pos_id]], dim=1)
        targets = labels.to(pair_logits.device).long()

        criterion = torch.nn.CrossEntropyLoss()
        return criterion(pair_logits, targets)
class InfoNCELoss(Loss):
    """InfoNCE contrastive loss over (anchor, positive, negatives) groups."""

    def __init__(self,
                 temperature: float = 0.1,
                 cross_batch: bool = True,
                 hard_negatives: int = None,
                 mask_fake_negative: bool = False,
                 fake_negative_margin: float = 0.1,
                 include_qq: bool = False,
                 include_dd: bool = False):
        """InfoNCE loss.

        Args:
            temperature: The temperature applied before the similarity softmax,
                default ``0.1``. The loss is smoother when the temperature is smaller.
            cross_batch: Whether to calculate similarity across all samples in one batch.
            hard_negatives: Pad or cut negative samples to this fixed length; enables the
                faster batched code path.
            mask_fake_negative: Whether to mask negatives whose similarity is too high.
            fake_negative_margin: Used with ``mask_fake_negative``: a negative is masked
                when its similarity exceeds the positive similarity plus this margin.
            include_qq: Whether to include query-query similarities.
            include_dd: Whether to include document-document similarities.
        """
        self.temperature = temperature
        self.cross_batch = cross_batch
        self.hard_negatives = hard_negatives
        self.mask_fake_negative = mask_fake_negative
        self.fake_negative_margin = fake_negative_margin
        self.include_qq = include_qq
        self.include_dd = include_dd

    @staticmethod
    def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None):
        """Split the flat tensor into per-sample (anchor, positive, negatives) groups."""
        split_indices = torch.nonzero(labels, as_tuple=False).squeeze().tolist()
        if isinstance(split_indices, int):
            split_indices = [split_indices]
        split_indices.append(len(labels))
        # Labels exclude the anchors, so shift each boundary by its group index.
        split_indices = np.array(split_indices) + np.array(list(range(len(split_indices))))
        split_tensors = []

        for i in range(len(split_indices) - 1):
            start = split_indices[i]
            end = split_indices[i + 1]
            split_part = sentences[start:end]
            if hard_negatives is not None:
                negatives = len(split_part) - 2
                assert negatives > 0
                if negatives > hard_negatives:
                    split_part = split_part[:hard_negatives + 2]
                elif negatives < hard_negatives:
                    # Resample existing negatives to pad up to the fixed length.
                    selected = np.random.choice(list(range(negatives)), size=hard_negatives - negatives, replace=True)
                    selected += 1  # skip positive
                    split_part = torch.cat((split_part, split_part[selected]), dim=0)
            split_tensors.append(split_part)
        return split_tensors

    def _gather_data(self, logits, labels):
        # gather all the logits and labels across the gpus when calculating
        # loss across all batches of all gpus
        from accelerate.utils import gather_object
        rank = torch_util.get_rank()
        all_preds = gather_object(logits.unsqueeze(0))
        labels = gather_object(labels)
        # override the gathered one
        all_preds[rank] = logits
        for idx in range(len(all_preds)):
            if idx == rank:
                continue
            # we don't calculate grad from other gpus
            all_preds[idx] = all_preds[idx].detach().to(logits.device)
        logits = torch.cat(all_preds, dim=0)
        labels = [tensor.to(logits.device) for tensor in labels]
        labels = torch.stack(labels, dim=0)
        return logits, labels

    def _calculate_batched_loss(self, split_tensors):
        # negative numbers are equal
        # [B, neg+2, D]
        sentences = torch.stack(split_tensors, dim=0)
        # [B, 1, D] * [B, neg+1, D]
        similarity_matrix = torch.matmul(sentences[:, 0:1], sentences[:, 1:].transpose(1, 2)) / self.temperature
        # The positive one is the first element
        labels = torch.zeros(len(split_tensors), dtype=torch.int64).to(sentences.device)
        return torch.nn.CrossEntropyLoss()(similarity_matrix.squeeze(1), labels)

    def _calculate_looped_loss(self, split_tensors):
        loss = 0
        # the negative numbers may be different, use for loop
        for tensor in split_tensors:
            # [D] * [neg+1, D]
            similarity_matrix = torch.matmul(tensor[0], tensor[1:].T) / self.temperature
            # The positive one is the first element
            labels = torch.tensor(0).to(tensor.device)
            loss += torch.nn.CrossEntropyLoss()(similarity_matrix, labels)
        # avg between all batches in one gpu
        loss /= len(split_tensors)
        return loss

    def _calculate_batched_loss_across_device(self, split_tensors):
        # [B, neg+2, D]
        sentences = torch.stack(split_tensors, dim=0)
        # base q->d similarities (includes own positive and all in-batch documents)
        queries = sentences[:, 0].squeeze(1)  # [B, D]
        docs_all = sentences[:, 1:].reshape(-1, sentences.size(2))  # [B*(neg+1), D]
        qd_matrix = torch.matmul(queries, docs_all.T)  # [B, B*(neg+1)]
        # target indices: start of each group's document block (its positive)
        labels = torch.tensor(range(0,
                                    sentences.size(0) * (sentences.size(1) - 1),
                                    sentences.size(1) - 1)).view(-1).to(sentences.device)

        logits_list = [qd_matrix]

        if self.include_qq:
            # q->q similarities; exclude self via -inf on diagonal to avoid accidental positives
            qq_matrix = torch.matmul(queries, queries.T)  # [B, B]
            qq_matrix = qq_matrix.clone()
            qq_matrix.fill_diagonal_(float('-inf'))
            logits_list.append(qq_matrix)

        if self.include_dd:
            # d+ -> d (doc-doc) similarities; exclude self-positive column per row
            pos_docs = sentences[:, 1].squeeze(1)  # [B, D]
            dd_matrix = torch.matmul(pos_docs, docs_all.T)  # [B, B*(neg+1)]
            # mask self positive per row: column index = row_idx * (neg+1)
            block = sentences.size(1) - 1  # (neg+1)
            if block > 0:
                row_idx = torch.arange(dd_matrix.size(0), device=dd_matrix.device)
                col_idx = row_idx * block
                dd_matrix[row_idx, col_idx] = float('-inf')
            logits_list.append(dd_matrix)

        if self.mask_fake_negative:
            # thresholds derived from positive q->d scores per row
            row_idx = torch.arange(qd_matrix.size(0), device=qd_matrix.device)
            pos_scores = qd_matrix[row_idx, labels]
            # FIX: was `self.fake_neg_margin`, which is never set in __init__
            # (the attribute is `fake_negative_margin`) and raised AttributeError.
            thresholds = pos_scores.view(-1, 1).detach() + self.fake_negative_margin

            # qd block mask
            qd_block = qd_matrix.clone()
            qd_mask = qd_block > thresholds
            qd_block[qd_mask] = float('-inf')

            components = [qd_block]

            # qq block mask (if present)
            if self.include_qq:
                qq_block = qq_matrix.clone()
                qq_mask = qq_block > thresholds
                qq_block[qq_mask] = float('-inf')
                # diagonal already masked unconditionally at construction time
                components.append(qq_block)

            # dd block (if present): self-positive column already masked unconditionally
            if self.include_dd:
                # align with Qwen3-Embedding, no threshold masking for d-d
                components.append(dd_matrix)

            similarity_matrix = torch.cat(components, dim=1)
        else:
            # concatenate all components without masking
            similarity_matrix = torch.cat(logits_list, dim=1)
        # temperature scaling and CE
        similarity_matrix = similarity_matrix / self.temperature
        return torch.nn.CrossEntropyLoss()(similarity_matrix, labels)

    def _calculate_looped_loss_across_device(self, split_tensors):
        all_tensors = []
        loss = 0
        for tensor in split_tensors:
            all_tensors.append(tensor[1:])
        # cat all neg+1 tensors
        sentences = torch.cat(all_tensors, dim=0)
        # prepare query anchors list if q-q is included
        if self.include_qq:
            queries_all = torch.stack([t[0] for t in split_tensors], dim=0)  # [B, D]
        length = 0
        for idx, tensor in enumerate(split_tensors):
            # [D] * [B*(neg+1), D], neg numbers are different
            qd_vec = torch.matmul(tensor[0], sentences.T)
            target = torch.tensor(length).to(tensor.device)
            logits_parts = []

            # compute threshold from positive q->d score
            # FIX: was `self.fake_neg_margin` (AttributeError), see __init__.
            threshold = (qd_vec[target].detach() + self.fake_negative_margin)

            # qd part with masking
            if self.mask_fake_negative:
                qd_masked = torch.where(qd_vec > threshold, torch.tensor(float('-inf'), device=qd_vec.device),
                                        qd_vec)
            else:
                qd_masked = qd_vec
            logits_parts.append(qd_masked)

            # qq part
            if self.include_qq:
                qq_vec = torch.matmul(tensor[0], queries_all.T)  # [B]
                # exclude self
                qq_vec = qq_vec.clone()
                qq_vec[idx] = float('-inf')
                if self.mask_fake_negative:
                    qq_vec = torch.where(qq_vec > threshold, torch.tensor(float('-inf'), device=qq_vec.device),
                                         qq_vec)
                logits_parts.append(qq_vec)

            # dd part
            if self.include_dd:
                dd_vec = torch.matmul(tensor[1], sentences.T)  # [B*(neg+1)]
                # mask self positive column for this row only (no threshold masking for d-d)
                dd_vec[length] = float('-inf')
                logits_parts.append(dd_vec)

            logits_row = torch.cat(logits_parts, dim=-1)
            logits_row = logits_row / self.temperature
            loss += torch.nn.CrossEntropyLoss()(logits_row.unsqueeze(0), target.unsqueeze(0))
            # next positive is neg+1
            length += tensor.size(0) - 1
        loss /= len(split_tensors)
        return loss

    def __call__(self, logits: 'torch.Tensor', labels: 'torch.Tensor', **kwargs):
        """Calculate loss.

        Args:
            logits: shape [Length * Hidden_size], Length=Repeat of (anchor(1)+positive(1)+negatives(n))
            labels: shape [Length], Length=Repeat of (positive(1)+negatives(n))
                For example:
                logits = [[0.1], [0.25], [0.15], ... [0.2], [0.3], [0.14]]
                labels = [1, 0, 1, 0] meaning positive/negative per non-anchor row.

        Returns:
            The loss tensor
        """
        world_size = torch_util.get_world_size()
        if world_size > 1 and self.cross_batch:
            logits, labels = self._gather_data(logits, labels)

        # split tensors into single sample
        # Example: batch_size=2 with tensor anchor(1)+positive(1)+negatives(3) + anchor(1)+positive(1)+negatives(2)
        # labels will be [1,0,0,0,1,0,0], meaning 1 positive, 3 negatives, 1 positive, 2 negatives
        split_tensors = self._parse_multi_negative_sentences(logits, labels, self.hard_negatives)
        can_batched = self.hard_negatives is not None
        if self.hard_negatives is None and len(set([s.shape[0] for s in split_tensors])) == 1:
            # all tensors have the same batch size
            can_batched = True
        if not self.cross_batch:
            # only calculate loss inside one sample
            if can_batched:
                loss = self._calculate_batched_loss(split_tensors)
            else:
                loss = self._calculate_looped_loss(split_tensors)
        else:
            if can_batched:
                loss = self._calculate_batched_loss_across_device(split_tensors)
            else:
                loss = self._calculate_looped_loss_across_device(split_tensors)
        return loss
**kwargs): + """ + List-wise generative reranker loss function. + + This loss function combines the generative reranker approach (using token probabilities) + with list-wise ranking. It groups samples by query based on the pattern where each group + consists of 1 positive document followed by n negative documents, then uses the + probabilities of specific tokens (e.g., "yes"/"no") to perform ranking within each group. + + Data format expected: + - labels: [1, 0, 0, 0, 1, 0, 0, ...] where 1 indicates positive, 0 indicates negative + - Each 1 is followed by its corresponding negative documents until the next 1 + + Environment variables for configuration: + - GENERATIVE_RERANKER_POSITIVE_TOKEN: Token for positive relevance (default: "yes") + - GENERATIVE_RERANKER_NEGATIVE_TOKEN: Token for negative relevance (default: "no") + - LISTWISE_RERANKER_TEMPERATURE: Temperature for softmax (default: 1.0) + - LISTWISE_RERANKER_MIN_GROUP_SIZE: Minimum group size to include (default: 2) + + Args: + outputs: Model outputs containing logits [batch_size, seq_len, vocab_size] + labels: Binary labels (1 for positive, 0 for negative) [batch_size] + loss_scale: Not used for listwise generative reranker + num_items_in_batch: Not used for listwise generative reranker + trainer: Trainer instance to access tokenizer + + Returns: + torch.Tensor: Cross entropy loss for ranking classification based on token probabilities + """ + # Get token IDs for positive and negative tokens + try: + positive_token_id = self.tokenizer.convert_tokens_to_ids(self.positive_token) + negative_token_id = self.tokenizer.convert_tokens_to_ids(self.negative_token) + except Exception as e: + raise ValueError(f"Failed to convert tokens '{self.positive_token}'/'{self.negative_token}' to IDs. " + f'Please check if these tokens exist in the tokenizer vocabulary. 
Error: {e}') + + # Extract logits at the last valid (non-padding) token position for each sample + batch_size = logits.shape[0] + batch_indices = torch.arange(batch_size, device=logits.device) + last_valid_logits = logits[batch_indices, last_valid_indices, :] + + positive_logits = last_valid_logits[:, positive_token_id] # [batch_size] + negative_logits = last_valid_logits[:, negative_token_id] # [batch_size] + + logits = torch.nn.functional.logsigmoid(positive_logits - negative_logits) + + # Find positive sample indices to determine group boundaries + positive_indices = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1) + + if len(positive_indices) == 0: + # No positive samples in this batch, return zero loss + return torch.tensor(0.0, device=logits.device, requires_grad=True) + + # Ensure positive_indices is 1D + if positive_indices.dim() == 0: + positive_indices = positive_indices.unsqueeze(0) + + total_loss = 0.0 + num_groups = 0 + + for i, pos_idx in enumerate(positive_indices): + # Determine group boundaries + group_start = pos_idx.item() + + # Find the end of current group (start of next group or end of batch) + if i + 1 < len(positive_indices): + group_end = positive_indices[i + 1].item() + else: + group_end = len(labels) + + # Extract group relevance scores and labels + group_scores = logits[group_start:group_end] # [group_size] + group_labels = labels[group_start:group_end] # [group_size] + + # Skip groups that are too small + if len(group_scores) < self.min_group_size: + continue + + # Verify that the first sample in the group is positive + if group_labels[0] != 1: + continue # Skip malformed groups + + group_logits = group_scores / self.temperature + + # The positive document is always at index 0 within the group + target = torch.tensor(0, dtype=torch.long, device=logits.device) + + # Apply cross-entropy loss: positive document should have highest relevance score + loss_fct = torch.nn.CrossEntropyLoss() + group_loss = 
loss_fct(group_logits.unsqueeze(0), target.unsqueeze(0)) + + total_loss += group_loss + num_groups += 1 + + if num_groups == 0: + return torch.tensor(0.0, device=logits.device, requires_grad=True) + + # Return average loss across all groups + return total_loss / num_groups \ No newline at end of file diff --git a/src/twinkle/loss/listwise_reranker.py b/src/twinkle/loss/listwise_reranker.py new file mode 100644 index 00000000..ca6ee67f --- /dev/null +++ b/src/twinkle/loss/listwise_reranker.py @@ -0,0 +1,90 @@ +from src.twinkle.loss.base import Loss +import torch + + +class ListwiseRerankerLoss(Loss): + + def __init__(self, temperature=1.0, min_group_size=2): + self.temperature = temperature + self.min_group_size = min_group_size + + def __call__(self, logits, labels, **kwargs): + """ + List-wise reranker loss function. + + This loss function groups samples by query based on the pattern where each group + consists of 1 positive document followed by n negative documents. It treats the + ranking task as a classification problem within each group, using cross-entropy + loss to identify the positive document among all candidates. + + Data format expected: + - labels: [1, 0, 0, 0, 1, 0, 0, ...] 
class ListwiseRerankerLoss(Loss):
    """List-wise reranker loss over score logits.

    Samples are grouped by query: each group is one positive document
    followed by its negatives, so ``labels`` looks like ``[1, 0, 0, 1, 0, ...]``.
    Ranking is treated as classification within each group: cross entropy
    pushes the positive document (always index 0 of its group) above its
    negatives.
    """

    def __init__(self, temperature=1.0, min_group_size=2):
        self.temperature = temperature
        self.min_group_size = min_group_size

    def __call__(self, logits, labels, **kwargs):
        """Compute the loss.

        Args:
            logits: Per-sample relevance scores.
            labels: Binary labels (1 positive, 0 negative) ``[batch]``.

        Returns:
            torch.Tensor: Loss averaged over valid groups; zero (with grad)
            when no valid group exists. Groups smaller than
            ``min_group_size`` or not starting with a positive are skipped.
        """
        pos_positions = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1)
        if len(pos_positions) == 0:
            # No positive samples in this batch, nothing to rank.
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
        if pos_positions.dim() == 0:
            pos_positions = pos_positions.unsqueeze(0)

        # Each group runs from one positive up to (excluding) the next.
        starts = [p.item() for p in pos_positions]
        ends = starts[1:] + [len(labels)]

        total_loss = 0.0
        num_groups = 0
        criterion = torch.nn.CrossEntropyLoss()
        for group_start, group_end in zip(starts, ends):
            group_logits = logits[group_start:group_end]
            group_labels = labels[group_start:group_end]

            if len(group_logits) < self.min_group_size:
                continue  # too small to rank
            if group_labels[0] != 1:
                continue  # malformed group

            scaled_logits = group_logits / self.temperature
            # The positive document is always at index 0 within the group.
            target = torch.tensor(0, dtype=torch.long, device=logits.device)
            total_loss += criterion(scaled_logits.unsqueeze(0), target.unsqueeze(0))
            num_groups += 1

        if num_groups == 0:
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
        return total_loss / num_groups


class MSELoss(Loss):
    """Mean squared error between predictions and targets."""

    def __call__(self, preds, labels, **kwargs):
        criterion = torch.nn.MSELoss()
        return criterion(preds, labels)


class OnlineContrastiveLoss(Loss):
    """Contrastive loss restricted to online-mined hard pairs.

    Only positives farther than the closest negative and negatives closer
    than the farthest positive contribute (with mean-based fallbacks when
    the other side has at most one element).
    """

    def __call__(self, sentence1, sentence2, labels, **kwargs):
        distances = SiameseDistanceMetric.COSINE_DISTANCE(sentence1, sentence2)
        neg_d = distances[labels == 0]
        pos_d = distances[labels == 1]

        # select hard positive and hard negative pairs
        hard_negatives = neg_d[neg_d < (pos_d.max() if len(pos_d) > 1 else neg_d.mean())]
        hard_positives = pos_d[pos_d > (neg_d.min() if len(neg_d) > 1 else pos_d.mean())]

        margin = 0.5
        positive_loss = hard_positives.pow(2).sum()
        negative_loss = torch.nn.functional.relu(margin - hard_negatives).pow(2).sum()
        return positive_loss + negative_loss
class RerankerLoss(Loss):
    """Point-wise reranker loss: BCE-with-logits over a single score logit.

    ``logits`` arrives as ``[batch, 1]`` and is squeezed to ``[batch]``
    before being compared against the 0/1 relevance labels.
    """

    def __call__(self, logits, labels, **kwargs):
        scores = logits.squeeze(1)
        targets = labels.to(scores.dtype)
        criterion = torch.nn.BCEWithLogitsLoss()
        return criterion(scores, targets)
+ images_keys = ['images', 'image'] + audios_keys = ['audios', 'audio'] + videos_keys = ['videos', 'video'] + for mm_type in ['images', 'audios', 'videos']: + keys = locals()[f'{mm_type}_keys'] + for key in keys: + self.columns[key] = mm_type + + self.traceback_limit = traceback_limit + self._traceback_counter = 0 + self.dataset_sample = dataset_sample + if not isinstance(random_state, np.random.RandomState): + random_state = np.random.RandomState(random_state) + self.random_state = random_state + + @staticmethod + def _check_messages(row: Dict[str, Any]) -> None: + if 'messages' not in row: + return + messages = row['messages'] + assert len(messages) > 0, f'messages: {messages}' + # fix swift/SlimOrca + for message in messages: + keys = set(message.keys()) - {'role', 'content', 'loss'} + for key in keys: + message.pop(key) + + for message in messages: + role, content = message['role'], message['content'] + # The terms 'tool' and 'tool_response' have the same meaning, ensuring compatibility. 
+ assert role in {'system', 'user', 'tool_call', 'tool_response', 'tool', 'assistant'}, f'message: {message}' + assert content is not None, f'message: {message}' + + @staticmethod + def _cast_mm_data(row: Dict[str, Any]) -> None: + for key in ['images', 'rejected_images']: + images = row.get(key, None) + if images is None: + continue + + if isinstance(images, str) or (isinstance(images, list) and images and isinstance(images[0], str)): + if isinstance(images, str): + images = [images] + for i, image in enumerate(images): + images[i] = {'bytes': None, 'path': image} + row[key] = images + elif isinstance(images, dict): + row[key] = [images] + + for key in ['videos', 'audios']: + mm_data = row.get(key) + if mm_data is None: + continue + elif isinstance(mm_data, str): + row[key] = [mm_data] + + @staticmethod + def _check_rejected_response(row: Dict[str, Any]) -> None: + if 'rejected_response' in row: + messages = row['messages'] + rejected_response = row['rejected_response'] + if (rejected_response is None + or isinstance(rejected_response, str) and rejected_response == messages[-1]['content']): + raise ValueError(f'rejected_response: {rejected_response}') + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + raise NotImplementedError + + def prepare_dataset(self, dataset: DATASET_TYPE) -> DATASET_TYPE: + return dataset + + @staticmethod + def batched_to_rows(batched_row: Dict[str, Any]): + keys = list(batched_row.keys()) + batch_size = len(batched_row[keys[0]]) + return [{key: batched_row[key][i] for key in keys} for i in range(batch_size)] + + @staticmethod + def rows_to_batched(rows: List[Dict[str, Any]]): + batched = {} + for i, row in enumerate(rows): + for k, v in row.items(): + if k not in batched: + batched[k] = [None] * i + batched[k].append(v) + # Make all the lengths of v the same. 
+ for k in set(batched.keys()) - set(row.keys()): + batched[k].append(None) + return batched + + @staticmethod + def _remove_prefix_keys(row, prefix: str): + for k in list(row.keys()): + if k.startswith(prefix): + new_k = k[len(prefix):] + new_v = row.pop(k) + if new_k not in row: + row[new_k] = new_v + + @staticmethod + def _check_objects(row): + objects = row.get('objects') + if objects is None: + return + new_objects = {} + # Ensure the order + for k in ['ref', 'bbox', 'bbox_type', 'image_id']: + if k in objects.keys(): + new_objects[k] = objects[k] + row['objects'] = new_objects + bbox = new_objects['bbox'] + + # check bbox + for box in bbox: + assert len(box) in {2, 4}, f'len(box): {len(box)}' + if len(box) == 2: + continue + if box[0] > box[2]: + box[0], box[2] = box[2], box[0] + if box[1] > box[3]: + box[1], box[3] = box[3], box[1] + + def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool, + ignore_max_length_error: bool) -> Dict[str, Any]: + from ...template import MaxLengthError + batched_row = dict(batched_row) + assert len(batched_row) > 0 + self._remove_prefix_keys(batched_row, '__@') # compat streaming + rows = self.batched_to_rows(batched_row) + + new_rows = [] + for row in rows: + try: + row = self.preprocess(row) + # support [row1, row2, ...] 
+ if row is None: + row = [] + if isinstance(row, dict): + row = [row] + for r in row: + self._check_objects(r) + self._check_rejected_response(r) + self._check_messages(r) + self._cast_mm_data(r) + except Exception as e: + if strict: + logger.warning('To avoid errors, you can pass `strict=False`.') + raise + if isinstance(e, MaxLengthError) and ignore_max_length_error: + pass + elif self.traceback_limit is not None and self._traceback_counter < self.traceback_limit: + import traceback + logger.info(traceback.format_exc()) + logger.warning('👆👆👆There are errors in the dataset, the data will be deleted') + self._traceback_counter += 1 + row = [] + new_rows += row + res = self.rows_to_batched(new_rows) + self._remove_prefix_keys(res, '__#') # compat GRPO + if len(res) == 0: + res['messages'] = [] + + return res + + @staticmethod + def get_features_dataset(dataset: DATASET_TYPE) -> DATASET_TYPE: + if dataset.features is None: + assert isinstance(dataset, HfIterableDataset) + dataset = dataset._resolve_features() + return dataset + + @staticmethod + def safe_rename_columns(dataset, columns): + dataset = RowPreprocessor.get_features_dataset(dataset) + columns_keys = {k.lower(): k for k in dataset.features.keys()} # lower -> lower/upper + safe_columns = {columns_keys[k.lower()]: v for k, v in columns.items() if k.lower() in columns_keys} + + counter = Counter(safe_columns.values()) + for k, new_k in list(safe_columns.items()): + if counter[new_k] > 1: + # For example, if "response" and "answer" match, then no processing is done. + safe_columns.pop(k) + continue + + # e.g. Keep {'query': 'query'} to ensure that the query has the highest priority. 
+ safe_columns = {k: v for k, v in safe_columns.items() if k != v} + if safe_columns: + dataset = dataset.rename_columns(safe_columns) + + return dataset + + @staticmethod + def remove_useless_columns(dataset: DATASET_TYPE) -> DATASET_TYPE: + dataset = RowPreprocessor.get_features_dataset(dataset) + features = dataset.features + k_list = [k for k in RowPreprocessor.standard_keys if k in features] + if len(k_list) != len(features): + dataset = dataset.select_columns(k_list) + return dataset + + @staticmethod + @contextmanager + def _patch_arrow_writer(): + # fix AI-ModelScope/ms_agent_for_agentfabric:all + from datasets.arrow_writer import ArrowWriter + + def _new_init(self, schema=None, features=None, *args, **kwargs): + + if features is not None: + messages_feature = [{ + 'role': Value(dtype='string'), + 'content': Value(dtype='string'), + }] + messages_feature_with_loss = [{ + 'role': Value(dtype='string'), + 'content': Value(dtype='string'), + 'loss': Value(dtype='float64'), + }] + features['messages'] = messages_feature_with_loss + features['rejected_messages'] = messages_feature_with_loss + features['positive_messages'] = [messages_feature] + features['negative_messages'] = [messages_feature] + features['images'] = [{'bytes': Value(dtype='binary'), 'path': Value(dtype='string')}] + features['objects'] = { + 'ref': Sequence(feature=Value(dtype='string'), length=-1), + 'bbox': Sequence(feature=Sequence(feature=Value(dtype='float64'), length=-1), length=-1), + 'bbox_type': Value(dtype='string'), + 'image_id': Sequence(feature=Value(dtype='int64'), length=-1), + } + ArrowWriter.__origin_init__(self, schema, features, *args, **kwargs) + + ArrowWriter.__origin_init__ = ArrowWriter.__init__ + ArrowWriter.__init__ = _new_init + try: + yield + finally: + ArrowWriter.__init__ = ArrowWriter.__origin_init__ + del ArrowWriter.__origin_init__ + + def _cast_pil_image(self, dataset): + features = dataset.features + for col in ['images', 'rejected_images']: + if (col in 
features and isinstance(features[col], Image) and getattr(features[col], 'decode', False)): + dataset = dataset.cast_column(col, Image(decode=False)) + return dataset + + def __call__( + self, + dataset: DATASET_TYPE, + *, + num_proc: int = 1, + load_from_cache_file: bool = True, + strict: bool = False, + batch_size: Optional[int] = None, + ) -> DATASET_TYPE: + from ..utils import sample_dataset + if batch_size is None: + batch_size = 1000 if isinstance(dataset, HfDataset) else 16 + if self.dataset_sample is not None: + dataset = sample_dataset(dataset, self.dataset_sample, True, self.random_state) + + map_kwargs = {'batched': True, 'batch_size': batch_size} + if isinstance(dataset, HfDataset): + if not load_from_cache_file and is_dist() and not is_master(): + load_from_cache_file = True + map_kwargs.update({ + 'num_proc': num_proc, + 'load_from_cache_file': load_from_cache_file, + }) + # compat GRPO: The solution field will be retained. + dataset = RowPreprocessor.get_features_dataset(dataset) + if 'solution' in dataset.features: + with safe_ddp_context(None, True): + if isinstance(dataset, HfDataset) and not dataset.cache_files: + map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache', + f'{dataset._fingerprint}.arrow') + dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs) + map_kwargs.pop('cache_file_name', None) + dataset = self.safe_rename_columns(dataset, self.origin_columns) + dataset = self.safe_rename_columns(dataset, self.columns) + dataset = self.prepare_dataset(dataset) + dataset = self._cast_pil_image(dataset) + if isinstance(dataset, HfIterableDataset): + # fix: https://github.com/huggingface/datasets/issues/6408 + columns = {k: f'__@{k}' for k in RowPreprocessor.standard_keys if k in dataset.features} + if columns: + dataset = dataset.rename_columns(columns) + + ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False + with self._patch_arrow_writer(), 
safe_ddp_context(None, True): + try: + if isinstance(dataset, HfDataset) and not dataset.cache_files: + map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache', + f'{dataset._fingerprint}.arrow') + dataset_mapped = dataset.map( + self.batched_preprocess, + fn_kwargs={ + 'strict': strict, + 'ignore_max_length_error': ignore_max_length_error + }, + remove_columns=list(dataset.features.keys()), + **map_kwargs) + except NotImplementedError: + pass + if isinstance(dataset_mapped, HfDataset) and len(dataset) != len(dataset_mapped): + logger.info( + f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}') + + return dataset_mapped + + +class ResponsePreprocessor(RowPreprocessor): + """Dataset compatible with older versions of ms-swift""" + + def __init__(self, *, columns: Optional[Dict[str, str]] = None, **kwargs) -> None: + super().__init__(columns=columns, **kwargs) + system_keys = ['system', 'system_prompt'] + query_keys = ['query', 'prompt', 'input', 'instruction', 'question', 'problem'] + response_keys = ['response', 'answer', 'output', 'targets', 'target', 'answer_key', 'answers', 'solution' + ] + ['text', 'completion', 'content'] + for key in system_keys: + self.columns[key] = 'system' + for key in query_keys: + self.columns[key] = 'query' + for key in response_keys: + self.columns[key] = 'response' + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + response = row.pop('response', None) + if response is not None: + if isinstance(response, (list, tuple)): + from transformers.utils import strtobool + # sometimes response is a list, pick one randomly + if strtobool(os.environ.get('RANDOM_DATASET_RESPONSE', 'False')): + response = self.random_state.choice(response) + else: + response = response[0] + history = row.pop('history', None) or [] + query = row.pop('query', None) + system = row.pop('system', None) + if isinstance(history, str): # e.g. 
def default_repair_messages(s: Union[str, Any]) -> Any:
    """Decode *s* with ``ast.literal_eval`` when it is a string.

    Non-string inputs (already-parsed message lists, dicts, ...) are returned
    untouched. Used as the default ``repair_messages`` hook so datasets that
    store conversations as Python-literal strings are transparently parsed.
    """
    return ast.literal_eval(s) if isinstance(s, str) else s
+ role_key: Optional[str] = None, # 'role', 'from' + content_key: Optional[str] = None, # 'content', 'value' + user_role: Optional[str] = None, # 'user', 'human' + assistant_role: Optional[str] = None, # 'assistant', 'gpt', 'bot' + system_role: str = 'system', + # 'conversation', 'conversations' -> 'messages' + columns: Optional[Dict[str, str]] = None, + repair_messages: Callable[[Union[str, List[Dict[str, str]]]], + Optional[List[Dict[str, str]]]] = default_repair_messages, + inner_key: Optional[str] = None, + **kwargs): + super().__init__(columns=columns, **kwargs) + self.role_keys = ['role', 'from'] if role_key is None else [role_key] + self.content_keys = ['content', 'value'] if content_key is None else [content_key] + self.user_roles = ['user', 'human'] if user_role is None else [user_role] + self.assistant_roles = ['assistant', 'gpt', 'bot'] if assistant_role is None else [assistant_role] + self.tool_call_roles = ['function_call'] + self.tool_response_roles = ['function_response', 'observation', 'observations'] + + self.system_role = system_role + self.repair_messages = repair_messages + self.inner_key = inner_key + + message_keys = ['messages', 'conversation', 'conversations'] + for key in message_keys: + self.columns[key] = 'messages' + # sharegptq + system_keys = ['system', 'system_prompt'] + if system_role not in system_keys: + system_keys.append(system_role) + for key in system_keys: + self.columns[key] = 'system' + + @staticmethod + def _is_sharegpt_format(message: Dict[str, str]) -> bool: + if 'role' in message or 'content' in message: + return False + return True + + def sharegpt_to_messages(self, messages: List[Dict[str, str]], system: Optional[str]) -> List[Dict[str, str]]: + self._to_std_key(messages, 'user', self.user_roles) + self._to_std_key(messages, 'assistant', self.assistant_roles) + new_messages = [] + if system is not None: + new_messages.append({'role': 'system', 'content': system}) + for message in messages: + user_message = {'role': 
'user', 'content': message['user']} + assistant_message = {'role': 'assistant', 'content': message['assistant']} + new_messages.append(user_message) + new_messages.append(assistant_message) + return new_messages + + def to_std_messages(self, messages: List[Dict[str, str]], system: Optional[str]) -> None: + if messages[0]['role'] == self.system_role: + messages[0]['role'] = 'system' + elif system is not None: + messages.insert(0, {'role': 'system', 'content': system}) + for message in messages: + role = message['role'] + if role in self.user_roles: + message['role'] = 'user' + elif role in self.assistant_roles: + message['role'] = 'assistant' + elif role.replace('-', '_') in self.tool_call_roles: + message['role'] = 'tool_call' + elif role.replace('-', '_') in self.tool_response_roles: + message['role'] = 'tool_response' + + @staticmethod + def _to_std_key(messages: List[Dict[str, str]], std_key: str, optional_keys: List[str]) -> None: + for message in messages: + for key in optional_keys: + if key in message: + message[std_key] = message.pop(key) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if 'rejected_messages' in row: + row['rejected_messages'] = MessagesPreprocessor.preprocess( + self, {'messages': row['rejected_messages']})['messages'] + messages = row['messages'] + if self.inner_key is not None: + messages = messages[self.inner_key] + messages: Optional[List[Dict[str, str]]] = self.repair_messages(messages) + if not messages or isinstance(messages, str): + return + self._to_std_key(messages, 'role', self.role_keys) + self._to_std_key(messages, 'content', self.content_keys) + system = row.pop('system', None) + if self._is_sharegpt_format(messages[0]): + messages = self.sharegpt_to_messages(messages, system) + else: + self.to_std_messages(messages, system) # inplace + row['messages'] = messages + return row + + +class ClsPreprocessor(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: 
class AutoPreprocessor:
    """Pick the row preprocessor that matches the dataset's column layout."""

    def __init__(self, *, columns: Optional[Dict[str, str]] = None, **kwargs) -> None:
        self.columns = columns or {}
        # Forwarded verbatim to whichever concrete preprocessor is chosen.
        self.kwargs = kwargs

    def _get_preprocessor(self, dataset: DATASET_TYPE) -> RowPreprocessor:
        # Priority: messages-style columns, then alpaca-style, else the
        # legacy query/response layout.
        features = dataset.features
        if any(key in features for key in ('conversation', 'conversations', 'messages')):
            return MessagesPreprocessor(**self.kwargs)
        if 'instruction' in features and 'input' in features:
            return AlpacaPreprocessor(**self.kwargs)
        return ResponsePreprocessor(**self.kwargs)

    def __call__(
        self,
        dataset: DATASET_TYPE,
        *,
        num_proc: int = 1,
        load_from_cache_file: bool = True,
        strict: bool = False,
    ) -> DATASET_TYPE:
        dataset = RowPreprocessor.safe_rename_columns(dataset, self.columns)
        preprocessor = self._get_preprocessor(dataset)
        return preprocessor(
            dataset, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict)
class GroundingMixin:
    """This class offers prompts to the grounding task.

    Subclasses set ``task_type`` to a key of ``_grounding_prompts``
    ('grounding' or 'caption'); ``construct_grounding_prompt`` then samples a
    (query, response) template pair, choosing the language according to
    ``_grounding_language_mixin``.
    """
    task_type: Optional[str] = None

    # Sampling probabilities for ['en', 'zh'] prompt templates.
    _grounding_language_mixin = [0.8, 0.2]
    _grounding_prompts = {
        'grounding': {
            'en': [('', ''), ('The positions of is', ''),
                   ('Find the positions of ', ''), ('Where is ', ''),
                   ('Find ', ''), ('Show me ', ''),
                   ('Detect ', ''), ('Locate ', ''),
                   ('Tell me the location of ', ''), ('Give the location of ', ''),
                   ('Provide the bounding box coordinate of ', '')],
            'zh': [('', ''), ('的位置在图片中', ''), ('在图片中', ''),
                   ('在', ''), ('找到的位置', ''), ('在哪里', ''),
                   ('提供的坐标位置', '')]
        },
        'caption': {
            'en': [
                ('', ''),
                ('The object at position ', ''),
                ('This is', ''),
                ('What is the object at ', ''),
                ('Describe ', ''),
                (' is', ''),
                ('The bounding box coordinate contains', ''),
            ],
            'zh': [
                ('', ''),
                ('是什么', ''),
                ('的位置包含', ''),
                ('描述', ''),
                ('中是', ''),
                ('坐标描述了什么', ''),
                ('描述中的事物', ''),
            ]
        },
    }

    def construct_grounding_prompt(self):
        # TODO Only support one bbox to one object
        # Fix: use the class-level ``_grounding_language_mixin`` instead of a
        # hard-coded duplicate of its value, so subclasses overriding the
        # language ratio actually take effect.
        lang = np.random.choice(['en', 'zh'], p=self._grounding_language_mixin)
        prompts = GroundingMixin._grounding_prompts[self.task_type][lang]
        query, response = prompts[np.random.choice(len(prompts))]
        return query, response
Optional[Dict[str, str]] = None, + **kwargs) -> None: + self.labels = labels + self.task = task + self.is_pair_seq = is_pair_seq + + category = ', '.join(labels) + self.sentence2_key = 'sentence2' + self.label_key = 'label' + if is_pair_seq: + self.sentence_key = 'sentence1' + inputs = 'Sentence1: {sentence1}\nSentence2: {sentence2}' + else: + self.sentence_key = 'sentence' + inputs = 'Sentence: {sentence}' + self.prompt = f"""Task: {task} +{inputs} +Category: {category} +Output:""" + super().__init__(columns=columns, **kwargs) + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + label = row.pop(self.label_key, None) + if label is None: + return + + if self.is_pair_seq: + query = self.prompt.format(sentence1=row.pop(self.sentence_key), sentence2=row.pop(self.sentence2_key)) + else: + query = self.prompt.format(sentence=row.pop(self.sentence_key)) + row['query'] = query + row['response'] = self.labels[int(label)] + return super().preprocess(row) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py new file mode 100644 index 00000000..0abceb1b --- /dev/null +++ b/src/twinkle/template/base.py @@ -0,0 +1,2049 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
class MaxLengthError(ValueError):
    """Raised when a sample exceeds the allowed maximum length.

    NOTE(review): semantics inferred from its handling in
    ``RowPreprocessor.batched_preprocess`` (ignored when
    ``ignore_max_length_error`` is set) — confirm against the template code.
    """